1 /*
2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 *
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
16 *
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
/* Frame geometry. FRAME_SIZE = 120 << 2 = 480 samples per hop (10 ms at the
 * 48 kHz rate enforced by query_formats()); each analysis window spans two
 * consecutive frames. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
#define WINDOW_SIZE (2*FRAME_SIZE)
/* Bins kept from the WINDOW_SIZE-point FFT of a real window; the remaining
 * bins are rebuilt by conjugate symmetry in inverse_transform(). */
#define FREQ_SIZE (FRAME_SIZE + 1)

/* Pitch period search range (in full-rate samples) and the length of the
 * pitch history buffer: the longest period plus one pitch analysis frame. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* Evaluates x twice: only pass side-effect-free expressions. */
#define SQUARE(x) ((x)*(x))

/* Number of frequency bands, whose edges are given by eband5ms[]. */
#define NB_BANDS 22

/* Number of past cepstral vectors kept in DenoiseState.cepstral_mem. */
#define CEPS_MEM 8
/* Number of leading cepstral coefficients that get delta features. */
#define NB_DELTA_CEPS 6

/* RNN input size: NB_BANDS cepstral values, three groups of NB_DELTA_CEPS
 * values, plus pitch period and spectral variability. */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)

/* NOTE(review): not used in this part of the file; presumably the scale
 * applied to the integer weights loaded by rnnoise_model_from_file(). */
#define WEIGHTS_SCALE (1.f/256)

/* Upper bound on layer width; equals the 128 limit that
 * rnnoise_model_from_file() enforces on every layer dimension. */
#define MAX_NEURONS 128

/* Internal activation identifiers stored in DenseLayer/GRULayer.activation. */
#define ACTIVATION_TANH 0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU 2

/* Unity gain; the name presumably comes from fixed-point (Q15) code this was
 * ported from. */
#define Q15ONE 1.0f
74
/* Fully connected layer as loaded by rnnoise_model_from_file(). */
typedef struct DenseLayer {
    const float *bias;          /* nb_neurons values */
    const float *input_weights; /* nb_inputs * nb_neurons values, in file order */
    int nb_inputs;
    int nb_neurons;
    int activation;             /* one of ACTIVATION_* */
} DenseLayer;
82
/* Gated recurrent layer as loaded by rnnoise_model_from_file(). Weight
 * matrices hold 3 blocks (presumably one per GRU gate) and are stored by
 * INPUT_ARRAY3 with the input dimension contiguous and padded to a multiple
 * of 4 floats. */
typedef struct GRULayer {
    const float *bias;              /* nb_neurons * 3 values */
    const float *input_weights;     /* nb_inputs x nb_neurons x 3, padded */
    const float *recurrent_weights; /* nb_neurons x nb_neurons x 3, padded */
    int nb_inputs;
    int nb_neurons;
    int activation;                 /* one of ACTIVATION_* */
} GRULayer;
91
/* Complete denoising network. Each *_size field equals the nb_neurons of the
 * layer next to it (set while parsing the model file). All layers are owned
 * by the model and released by rnnoise_model_free(). */
typedef struct RNNModel {
    int input_dense_size;
    const DenseLayer *input_dense;

    int vad_gru_size;
    const GRULayer *vad_gru;

    int noise_gru_size;
    const GRULayer *noise_gru;

    int denoise_gru_size;
    const GRULayer *denoise_gru;

    int denoise_output_size;
    const DenseLayer *denoise_output;

    int vad_output_size;
    const DenseLayer *vad_output;
} RNNModel;
111
/* Per-channel recurrent state. The three hidden-state vectors are allocated
 * (zeroed, rounded up to a multiple of 16 floats) in config_input(); model
 * points at the shared AudioRNNContext.model and is not owned here. */
typedef struct RNNState {
    float *vad_gru_state;
    float *noise_gru_state;
    float *denoise_gru_state;
    RNNModel *model;
} RNNState;
118
/* All per-channel processing state. */
typedef struct DenoiseState {
    float analysis_mem[FRAME_SIZE];          /* previous input frame (first half of the next analysis window) */
    float cepstral_mem[CEPS_MEM][NB_BANDS];  /* ring buffer of recent cepstra, indexed by memid */
    int memid;                               /* next slot to write in cepstral_mem */
    DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE]; /* overlap-add tail from the previous frame */
    float pitch_buf[PITCH_BUF_SIZE];         /* input history used for pitch analysis */
    float pitch_enh_buf[PITCH_BUF_SIZE];     /* NOTE(review): not used in this part of the file */
    float last_gain;                         /* pitch gain returned by remove_doubling() for the previous frame */
    int last_period;                         /* pitch period of the previous frame */
    float mem_hp_x[2];                       /* presumably the biquad() state of an input filter; used elsewhere */
    float lastg[NB_BANDS];                   /* presumably previous per-band gains; used elsewhere */
    RNNState rnn;
    AVTXContext *tx, *txi;                   /* forward / inverse WINDOW_SIZE-point FFT contexts */
    av_tx_fn tx_fn, txi_fn;
} DenoiseState;
134
/* Filter private context. */
typedef struct AudioRNNContext {
    const AVClass *class;

    char *model_name;   /* presumably the model file path set through an AVOption */

    int channels;       /* number of entries in st, set in config_input() */
    DenoiseState *st;   /* one state per channel */

    /* Analysis/synthesis window applied in frame_analysis()/frame_synthesis(). */
    DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
    /* Basis rows used by dct(); rows and columns are padded to a multiple of 4
     * because scalarproduct_float() is called with FFALIGN(NB_BANDS, 4). */
    DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];

    RNNModel *model;

    AVFloatDSPContext *fdsp;
} AudioRNNContext;
150
/* Activation codes as written in the model file. INPUT_ACTIVATION() in
 * rnnoise_model_from_file() maps them to ACTIVATION_*, treating any unknown
 * code as tanh. */
#define F_ACTIVATION_TANH 0
#define F_ACTIVATION_SIGMOID 1
#define F_ACTIVATION_RELU 2
154
/**
 * Release an RNNModel together with every weight and bias array its layers own.
 *
 * Layers that are still NULL are skipped. This makes it safe to call on the
 * partially built model that rnnoise_model_from_file() discards on error.
 * A NULL model is a no-op.
 */
static void rnnoise_model_free(RNNModel *model)
{
/* Every model buffer comes from av_calloc(), so it must be released with
 * av_free() and never libc free(). av_free(NULL) is already a no-op, so no
 * guard is needed. */
#define FREE_MAYBE(ptr) do { av_free(ptr); } while (0)
/* Layer pointers and arrays are const-qualified in the structs, hence the
 * casts when handing them back to av_free(). */
#define FREE_DENSE(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)
#define FREE_GRU(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->recurrent_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)

    if (!model)
        return;
    FREE_DENSE(input_dense);
    FREE_GRU(vad_gru);
    FREE_GRU(noise_gru);
    FREE_GRU(denoise_gru);
    FREE_DENSE(denoise_output);
    FREE_DENSE(vad_output);
    av_free(model);
}
184
rnnoise_model_from_file(FILE * f)185 static RNNModel *rnnoise_model_from_file(FILE *f)
186 {
187 RNNModel *ret;
188 DenseLayer *input_dense;
189 GRULayer *vad_gru;
190 GRULayer *noise_gru;
191 GRULayer *denoise_gru;
192 DenseLayer *denoise_output;
193 DenseLayer *vad_output;
194 int in;
195
196 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
197 return NULL;
198
199 ret = av_calloc(1, sizeof(RNNModel));
200 if (!ret)
201 return NULL;
202
203 #define ALLOC_LAYER(type, name) \
204 name = av_calloc(1, sizeof(type)); \
205 if (!name) { \
206 rnnoise_model_free(ret); \
207 return NULL; \
208 } \
209 ret->name = name
210
211 ALLOC_LAYER(DenseLayer, input_dense);
212 ALLOC_LAYER(GRULayer, vad_gru);
213 ALLOC_LAYER(GRULayer, noise_gru);
214 ALLOC_LAYER(GRULayer, denoise_gru);
215 ALLOC_LAYER(DenseLayer, denoise_output);
216 ALLOC_LAYER(DenseLayer, vad_output);
217
218 #define INPUT_VAL(name) do { \
219 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220 rnnoise_model_free(ret); \
221 return NULL; \
222 } \
223 name = in; \
224 } while (0)
225
226 #define INPUT_ACTIVATION(name) do { \
227 int activation; \
228 INPUT_VAL(activation); \
229 switch (activation) { \
230 case F_ACTIVATION_SIGMOID: \
231 name = ACTIVATION_SIGMOID; \
232 break; \
233 case F_ACTIVATION_RELU: \
234 name = ACTIVATION_RELU; \
235 break; \
236 default: \
237 name = ACTIVATION_TANH; \
238 } \
239 } while (0)
240
241 #define INPUT_ARRAY(name, len) do { \
242 float *values = av_calloc((len), sizeof(float)); \
243 if (!values) { \
244 rnnoise_model_free(ret); \
245 return NULL; \
246 } \
247 name = values; \
248 for (int i = 0; i < (len); i++) { \
249 if (fscanf(f, "%d", &in) != 1) { \
250 rnnoise_model_free(ret); \
251 return NULL; \
252 } \
253 values[i] = in; \
254 } \
255 } while (0)
256
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
259 if (!values) { \
260 rnnoise_model_free(ret); \
261 return NULL; \
262 } \
263 name = values; \
264 for (int k = 0; k < (len0); k++) { \
265 for (int i = 0; i < (len2); i++) { \
266 for (int j = 0; j < (len1); j++) { \
267 if (fscanf(f, "%d", &in) != 1) { \
268 rnnoise_model_free(ret); \
269 return NULL; \
270 } \
271 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
272 } \
273 } \
274 } \
275 } while (0)
276
277 #define INPUT_DENSE(name) do { \
278 INPUT_VAL(name->nb_inputs); \
279 INPUT_VAL(name->nb_neurons); \
280 ret->name ## _size = name->nb_neurons; \
281 INPUT_ACTIVATION(name->activation); \
282 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283 INPUT_ARRAY(name->bias, name->nb_neurons); \
284 } while (0)
285
286 #define INPUT_GRU(name) do { \
287 INPUT_VAL(name->nb_inputs); \
288 INPUT_VAL(name->nb_neurons); \
289 ret->name ## _size = name->nb_neurons; \
290 INPUT_ACTIVATION(name->activation); \
291 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
294 } while (0)
295
296 INPUT_DENSE(input_dense);
297 INPUT_GRU(vad_gru);
298 INPUT_GRU(noise_gru);
299 INPUT_GRU(denoise_gru);
300 INPUT_DENSE(denoise_output);
301 INPUT_DENSE(vad_output);
302
303 if (vad_output->nb_neurons != 1) {
304 rnnoise_model_free(ret);
305 return NULL;
306 }
307
308 return ret;
309 }
310
/**
 * Negotiate formats: planar float samples, any channel count, and 48 kHz only.
 *
 * @return 0 on success, AVERROR(ENOMEM) if a format list cannot be built,
 *         or the error returned by the ff_set_common_*() helpers
 */
static int query_formats(AVFilterContext *ctx)
{
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP,
        AV_SAMPLE_FMT_NONE
    };
    int sample_rates[] = { 48000, -1 };
    AVFilterChannelLayouts *layouts;
    AVFilterFormats *formats;
    int ret;

    /* Sample format: planar float only. */
    formats = ff_make_format_list(sample_fmts);
    if (!formats)
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_formats(ctx, formats)) < 0)
        return ret;

    /* Channel layout: any number of channels. */
    layouts = ff_all_channel_counts();
    if (!layouts)
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_channel_layouts(ctx, layouts)) < 0)
        return ret;

    /* Sample rate: 48 kHz only. */
    formats = ff_make_format_list(sample_rates);
    return formats ? ff_set_common_samplerates(ctx, formats) : AVERROR(ENOMEM);
}
341
/**
 * Create one DenoiseState per input channel.
 *
 * For each channel this:
 * - allocates the three GRU hidden-state vectors of the loaded model
 *   (zeroed, sized up to a multiple of 16 floats);
 * - creates the forward and inverse WINDOW_SIZE-point complex FFT contexts.
 *
 * On failure, partially initialized states stay in s->st. They are presumably
 * released by the filter's uninit callback (not shown here).
 *
 * @return 0 on success, a negative AVERROR code on failure
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;
    int ret;

    s->channels = inlink->channels;

    s->st = av_calloc(s->channels, sizeof(*s->st));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        st->rnn.model = s->model;
        /* av_calloc(nmemb, size): element count first, element size second. */
        st->rnn.vad_gru_state     = av_calloc(FFALIGN(s->model->vad_gru_size, 16),     sizeof(float));
        st->rnn.noise_gru_state   = av_calloc(FFALIGN(s->model->noise_gru_size, 16),   sizeof(float));
        st->rnn.denoise_gru_state = av_calloc(FFALIGN(s->model->denoise_gru_size, 16), sizeof(float));
        if (!st->rnn.vad_gru_state ||
            !st->rnn.noise_gru_state ||
            !st->rnn.denoise_gru_state)
            return AVERROR(ENOMEM);

        ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
        if (ret < 0)
            return ret;

        ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
        if (ret < 0)
            return ret;
    }

    return 0;
}
377
/**
 * Second-order IIR filter whose leading numerator and denominator
 * coefficients are both 1. For each sample n, with state (s0, s1):
 *
 *     y[n] = x[n] + s0
 *     s0   = s1 + (b[0]*x[n] - a[0]*y[n])
 *     s1   =       b[1]*x[n] - a[1]*y[n]
 *
 * @param y   output, N samples (may be the same buffer as x)
 * @param mem filter state {s0, s1}, updated in place
 * @param x   input, N samples
 * @param b   feed-forward coefficients b1, b2
 * @param a   feedback coefficients a1, a2
 * @param N   number of samples
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    float s0 = mem[0];
    float s1 = mem[1];

    for (int n = 0; n < N; n++) {
        const float in  = x[n];
        const float out = in + s0;
        const float next_s0 = s1 + (b[0] * in - a[0] * out);

        s1 = (b[1] * in - a[1] * out);
        s0 = next_s0;
        y[n] = out;
    }

    mem[0] = s0;
    mem[1] = s1;
}
391
/* memmove/memset/memcpy wrappers that take an element count instead of a byte
 * count (element size comes from *dst). The "0*((dst)-(src))" term always adds
 * zero. Its purpose is the pointer subtraction, which makes the compiler reject
 * dst/src of incompatible pointer types. */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
#define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
395
/**
 * Forward FFT of one real WINDOW_SIZE-sample window.
 *
 * The input is promoted to complex (zero imaginary part) and transformed with
 * st->tx. Only the first FREQ_SIZE bins are kept: for real input the others
 * are complex conjugates of these. No scaling is applied here;
 * inverse_transform() divides by WINDOW_SIZE.
 *
 * @param st  channel state owning the forward FFT context
 * @param out receives FREQ_SIZE complex bins
 * @param in  WINDOW_SIZE real samples
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    /* av_tx buffers must be aligned unless AV_TX_UNALIGNED is requested, and
     * config_input() creates the contexts with flags == 0. */
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    for (int i = 0; i < WINDOW_SIZE; i++) {
        x[i].re = in[i];
        x[i].im = 0;
    }

    /* For AV_TX_FLOAT_FFT the stride is the size of one (complex) sample. */
    st->tx_fn(st->tx, y, x, sizeof(AVComplexFloat));

    RNN_COPY(out, y, FREQ_SIZE);
}
410
/**
 * Inverse FFT back to WINDOW_SIZE real samples.
 *
 * The missing upper half of the spectrum is rebuilt by conjugate symmetry
 * from the FREQ_SIZE input bins. The result is transformed with st->txi, and
 * the real part is divided by WINDOW_SIZE.
 *
 * @param st  channel state owning the inverse FFT context
 * @param out receives WINDOW_SIZE real samples
 * @param in  FREQ_SIZE complex bins
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    /* av_tx buffers must be aligned unless AV_TX_UNALIGNED is requested, and
     * config_input() creates the contexts with flags == 0. */
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    RNN_COPY(x, in, FREQ_SIZE);

    /* Hermitian symmetry: X[N - k] = conj(X[k]). */
    for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
        x[i].re = x[WINDOW_SIZE - i].re;
        x[i].im = -x[WINDOW_SIZE - i].im;
    }

    /* For AV_TX_FLOAT_FFT the stride is the size of one (complex) sample. */
    st->txi_fn(st->txi, y, x, sizeof(AVComplexFloat));

    for (int i = 0; i < WINDOW_SIZE; i++)
        out[i] = y[i].re / WINDOW_SIZE;
}
428
/* Band edges in units of (1 << FRAME_SIZE_SHIFT) = 4 FFT bins. At 48 kHz with
 * a WINDOW_SIZE = 960 FFT one bin is 50 Hz, so one unit is 200 Hz. The
 * NB_BANDS edges below (frequencies in the header comment) bound the
 * NB_BANDS - 1 triangular interpolation intervals. */
static const uint8_t eband5ms[] = {
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
  0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
433
compute_band_energy(float * bandE,const AVComplexFloat * X)434 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
435 {
436 float sum[NB_BANDS] = {0};
437
438 for (int i = 0; i < NB_BANDS - 1; i++) {
439 int band_size;
440
441 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
442 for (int j = 0; j < band_size; j++) {
443 float tmp, frac = (float)j / band_size;
444
445 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
446 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
447 sum[i] += (1.f - frac) * tmp;
448 sum[i + 1] += frac * tmp;
449 }
450 }
451
452 sum[0] *= 2;
453 sum[NB_BANDS - 1] *= 2;
454
455 for (int i = 0; i < NB_BANDS; i++)
456 bandE[i] = sum[i];
457 }
458
compute_band_corr(float * bandE,const AVComplexFloat * X,const AVComplexFloat * P)459 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
460 {
461 float sum[NB_BANDS] = { 0 };
462
463 for (int i = 0; i < NB_BANDS - 1; i++) {
464 int band_size;
465
466 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
467 for (int j = 0; j < band_size; j++) {
468 float tmp, frac = (float)j / band_size;
469
470 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
471 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
472 sum[i] += (1 - frac) * tmp;
473 sum[i + 1] += frac * tmp;
474 }
475 }
476
477 sum[0] *= 2;
478 sum[NB_BANDS-1] *= 2;
479
480 for (int i = 0; i < NB_BANDS; i++)
481 bandE[i] = sum[i];
482 }
483
/**
 * Analyse one input frame: build the window [previous frame | current frame],
 * apply s->window, and compute its spectrum and band energies.
 *
 * @param X  receives FREQ_SIZE spectral bins
 * @param Ex receives NB_BANDS band energies
 * @param in FRAME_SIZE new samples; also saved in st->analysis_mem
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    float *const prev_half = buf;
    float *const cur_half  = buf + FRAME_SIZE;

    RNN_COPY(prev_half, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(cur_half, in, FRAME_SIZE);
    /* This frame becomes the first half of the next window. */
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);

    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    forward_transform(st, X, buf);
    compute_band_energy(Ex, X);
}
495
/**
 * Synthesize one output frame from a spectrum by inverse FFT, windowing and
 * overlap-add with the tail kept from the previous call.
 *
 * @param out receives FRAME_SIZE samples
 * @param y   FREQ_SIZE spectral bins
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    float *const tail = buf + FRAME_SIZE;

    inverse_transform(st, buf, y);
    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    /* Overlap-add: first half += previous frame's second half. */
    s->fdsp->vector_fmac_scalar(buf, st->synthesis_mem, 1.f, FRAME_SIZE);
    RNN_COPY(out, buf, FRAME_SIZE);
    /* Keep the second half for the next frame. */
    RNN_COPY(st->synthesis_mem, tail, FRAME_SIZE);
}
506
/*
 * Unrolled cross-correlation kernel. For k = 0..3 it accumulates
 *     sum[k] += x[0]*y[k] + x[1]*y[k+1] + ... + x[len-1]*y[k+len-1]
 * into the caller's accumulators (they are not cleared here).
 *
 * Reads x[0..len-1] and y[0..len+2]. y_0..y_3 form a sliding window over y
 * that is rotated at each step, so every y sample is loaded exactly once.
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    /* Only y_3 is initialized; presumably to silence maybe-uninitialized
     * warnings, as every path assigns it before use. */
    float y_0, y_1, y_2, y_3 = 0;
    int j;

    /* Prime the window with y[0..2]. */
    y_0 = *y++;
    y_1 = *y++;
    y_2 = *y++;

    /* Main loop: four x samples per iteration. */
    for (j = 0; j < len - 3; j += 4) {
        float tmp;

        tmp = *x++;
        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
        tmp = *x++;
        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
        tmp = *x++;
        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
        tmp = *x++;
        y_2 = *y++;
        sum[0] += tmp * y_3;
        sum[1] += tmp * y_0;
        sum[2] += tmp * y_1;
        sum[3] += tmp * y_2;
    }

    /* Tail: up to three remaining samples when len is not a multiple of 4. */
    if (j++ < len) {
        float tmp = *x++;

        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
    }

    if (j++ < len) {
        float tmp=*x++;

        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
    }

    if (j < len) {
        float tmp=*x++;

        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
    }
}
575
/**
 * Dot product of x[0..N-1] and y[0..N-1], accumulated left to right in float.
 * Returns 0 when N <= 0.
 */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;
    int n = 0;

    while (n < N) {
        acc += x[n] * y[n];
        n++;
    }

    return acc;
}
586
/**
 * Cross-correlation for lags 0..max_pitch-1:
 *     xcorr[lag] = sum_{i<len} x[i] * y[i + lag]
 * Groups of four lags go through xcorr_kernel(). Any leftover lags (when
 * max_pitch is not a multiple of 4) use celt_inner_prod().
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    while (lag + 3 < max_pitch) {
        float acc[4] = { 0 };

        xcorr_kernel(x, y + lag, acc, len);
        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = acc[k];
        lag += 4;
    }

    for (; lag < max_pitch; lag++)
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
}
607
celt_autocorr(const float * x,float * ac,const float * window,int overlap,int lag,int n)608 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
609 float *ac, /* out: [0...lag-1] ac values */
610 const float *window,
611 int overlap,
612 int lag,
613 int n)
614 {
615 int fastN = n - lag;
616 int shift;
617 const float *xptr;
618 float xx[PITCH_BUF_SIZE>>1];
619
620 if (overlap == 0) {
621 xptr = x;
622 } else {
623 for (int i = 0; i < n; i++)
624 xx[i] = x[i];
625 for (int i = 0; i < overlap; i++) {
626 xx[i] = x[i] * window[i];
627 xx[n-i-1] = x[n-i-1] * window[i];
628 }
629 xptr = xx;
630 }
631
632 shift = 0;
633 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
634
635 for (int k = 0; k <= lag; k++) {
636 float d = 0.f;
637
638 for (int i = k + fastN; i < n; i++)
639 d += xptr[i] * xptr[i-k];
640 ac[k] += d;
641 }
642
643 return shift;
644 }
645
/**
 * Levinson-Durbin recursion: derive p LPC coefficients from p+1
 * autocorrelation values.
 *
 * All coefficients stay zero when ac[0] is zero. The recursion stops early
 * once the prediction error falls below 0.001 * ac[0] (30 dB of gain).
 *
 * @param lpc out: [0...p-1] LPC coefficients
 * @param ac  in:  [0...p] autocorrelation values
 * @param p   prediction order
 */
static void celt_lpc(float *lpc,
                     const float *ac,
                     int p)
{
    float err = ac[0];

    RNN_CLEAR(lpc, p);
    if (ac[0] == 0)
        return;

    for (int i = 0; i < p; i++) {
        float acc = 0;
        float refl;

        /* Reflection coefficient for this order. */
        for (int j = 0; j < i; j++)
            acc += (lpc[j] * ac[i - j]);
        acc += ac[i + 1];
        refl = -acc / err;

        /* Update the coefficients symmetrically from both ends. */
        lpc[i] = refl;
        for (int lo = 0, hi = i - 1; lo < (i + 1) >> 1; lo++, hi--) {
            const float a = lpc[lo];
            const float b = lpc[hi];

            lpc[lo] = a + (refl * b);
            lpc[hi] = b + (refl * a);
        }

        err = err - (refl * refl * err);
        if (err < .001f * ac[0])
            break;
    }
}
678
/**
 * 5-tap FIR filter with persistent history:
 *     y[n] = x[n] + num[0]*x[n-1] + num[1]*x[n-2] + ... + num[4]*x[n-5]
 * where samples before the start come from mem (mem[0] is the most recent).
 * Safe to run in place (y == x): each input sample is read before its output
 * is written.
 *
 * @param mem 5-sample history, updated with the last inputs on return
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    const float c0 = num[0], c1 = num[1], c2 = num[2], c3 = num[3], c4 = num[4];
    float h0 = mem[0], h1 = mem[1], h2 = mem[2], h3 = mem[3], h4 = mem[4];

    for (int n = 0; n < N; n++) {
        const float sample = x[n];
        float acc = sample;

        acc += (c0 * h0);
        acc += (c1 * h1);
        acc += (c2 * h2);
        acc += (c3 * h3);
        acc += (c4 * h4);

        /* Shift the delay line by one sample. */
        h4 = h3;
        h3 = h2;
        h2 = h1;
        h1 = h0;
        h0 = sample;

        y[n] = acc;
    }

    mem[0] = h0;
    mem[1] = h1;
    mem[2] = h2;
    mem[3] = h3;
    mem[4] = h4;
}
721
pitch_downsample(float * x[],float * x_lp,int len,int C)722 static void pitch_downsample(float *x[], float *x_lp,
723 int len, int C)
724 {
725 float ac[5];
726 float tmp=Q15ONE;
727 float lpc[4], mem[5]={0,0,0,0,0};
728 float lpc2[5];
729 float c1 = .8f;
730
731 for (int i = 1; i < len >> 1; i++)
732 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
733 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
734 if (C==2) {
735 for (int i = 1; i < len >> 1; i++)
736 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
737 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
738 }
739
740 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
741
742 /* Noise floor -40 dB */
743 ac[0] *= 1.0001f;
744 /* Lag windowing */
745 for (int i = 1; i <= 4; i++) {
746 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
747 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
748 }
749
750 celt_lpc(lpc, ac, 4);
751 for (int i = 0; i < 4; i++) {
752 tmp = .9f * tmp;
753 lpc[i] = (lpc[i] * tmp);
754 }
755 /* Add a zero */
756 lpc2[0] = lpc[0] + .8f;
757 lpc2[1] = lpc[1] + (c1 * lpc[0]);
758 lpc2[2] = lpc[2] + (c1 * lpc[1]);
759 lpc2[3] = lpc[3] + (c1 * lpc[2]);
760 lpc2[4] = (c1 * lpc[3]);
761 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
762 }
763
/**
 * Two dot products sharing the same left operand, computed in one pass:
 * *xy1 = x . y01 and *xy2 = x . y02 over N elements.
 */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0;
    float acc2 = 0;

    for (int n = 0; n < N; n++) {
        const float xn = x[n];

        acc1 += (xn * y01[n]);
        acc2 += (xn * y02[n]);
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
777
compute_pitch_gain(float xy,float xx,float yy)778 static float compute_pitch_gain(float xy, float xx, float yy)
779 {
780 return xy / sqrtf(1.f + xx * yy);
781 }
782
/* For sub-multiple k > 2, remove_doubling() also checks the lag
 * second_check[k] * T0 / k. Entries 0-2 are unused: k starts at 2, and
 * k == 2 has its own rule. */
static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/**
 * Refine a pitch period estimate so that a multiple of the true period is
 * not picked. For k = 2..15, the sub-multiple T0/k is accepted when its
 * normalized correlation exceeds a threshold. The threshold depends on g0,
 * on continuity with the previous frame, and on how short the period is.
 *
 * Works on the 2x-decimated signal: all periods and N are halved on entry,
 * and the result is converted back as 2*T + offset.
 *
 * @param x           decimated signal. The analysed segment starts at
 *                    x[maxperiod/2]; earlier samples serve as lag history.
 * @param maxperiod   largest period (full rate)
 * @param minperiod   smallest period (full rate)
 * @param N           analysis length (full rate)
 * @param T0_         in: initial period estimate; out: refined period,
 *                    clamped to at least minperiod
 * @param prev_period previous frame's period (continuity bias)
 * @param prev_gain   previous frame's pitch gain (continuity bias)
 * @return pitch gain of the selected period, capped at its normalized correlation
 */
static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
                             int *T0_, int prev_period, float prev_gain)
{
    int k, i, T, T0;
    float g, g0;
    float pg;
    float xy,xx,yy,xy2;
    float xcorr[3];
    float best_xy, best_yy;
    int offset;
    int minperiod0;
    float yy_lookup[PITCH_MAX_PERIOD+1];

    minperiod0 = minperiod;
    /* Convert everything to the decimated rate. */
    maxperiod /= 2;
    minperiod /= 2;
    *T0_ /= 2;
    prev_period /= 2;
    N /= 2;
    x += maxperiod;
    if (*T0_>=maxperiod)
        *T0_=maxperiod-1;

    T = T0 = *T0_;
    dual_inner_prod(x, x, x-T0, N, &xx, &xy);
    /* yy_lookup[i] = energy of x[-i .. N-1-i], updated incrementally and
     * clamped at 0 against rounding drift. */
    yy_lookup[0] = xx;
    yy=xx;
    for (i = 1; i <= maxperiod; i++) {
        yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
        yy_lookup[i] = FFMAX(0, yy);
    }
    yy = yy_lookup[T0];
    best_xy = xy;
    best_yy = yy;
    g = g0 = compute_pitch_gain(xy, xx, yy);
    /* Look for any pitch at T/k */
    for (k = 2; k <= 15; k++) {
        int T1, T1b;
        float g1;
        float cont=0;
        float thresh;
        T1 = (2*T0+k)/(2*k);    /* round(T0 / k) */
        if (T1 < minperiod)
            break;
        /* Look for another strong correlation at T1b */
        if (k==2)
        {
            if (T1+T0>maxperiod)
                T1b = T0;
            else
                T1b = T0+T1;
        } else
        {
            T1b = (2*second_check[k]*T0+k)/(2*k);
        }
        dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
        xy = .5f * (xy + xy2);
        yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
        g1 = compute_pitch_gain(xy, xx, yy);
        /* Lower the threshold when T1 is close to the previous frame's period. */
        if (FFABS(T1-prev_period)<=1)
            cont = prev_gain;
        else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
            cont = prev_gain * .5f;
        else
            cont = 0;
        thresh = FFMAX(.3f, (.7f * g0) - cont);
        /* Bias against very high pitch (very short period) to avoid false-positives
           due to short-term correlation */
        /* NOTE(review): the else-if below can never be taken. T1 < 2*minperiod
         * implies T1 < 3*minperiod, which the first branch already handles.
         * This ordering is presumably inherited from the reference code. It is
         * left as is because reordering would change the detected pitch and
         * therefore the model's inputs. */
        if (T1<3*minperiod)
            thresh = FFMAX(.4f, (.85f * g0) - cont);
        else if (T1<2*minperiod)
            thresh = FFMAX(.5f, (.9f * g0) - cont);
        if (g1 > thresh)
        {
            best_xy = xy;
            best_yy = yy;
            T = T1;
            g = g1;
        }
    }
    best_xy = FFMAX(0, best_xy);
    if (best_yy <= best_xy)
        pg = Q15ONE;
    else
        pg = best_xy/(best_yy + 1);

    /* Sub-sample refinement from the correlations at T-1, T, T+1. */
    for (k = 0; k < 3; k++)
        xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
    if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
        offset = 1;
    else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
        offset = -1;
    else
        offset = 0;
    if (pg > g)
        pg = g;
    /* Back to the full rate. */
    *T0_ = 2*T+offset;

    if (*T0_<minperiod0)
        *T0_=minperiod0;
    return pg;
}
886
/**
 * Find the two lags in [0, max_pitch) with the largest xcorr[lag]^2 / Syy,
 * considering only positive correlations. Syy is the energy of
 * y[lag .. lag+len-1], kept at least 1.
 *
 * Candidates are compared as cross-multiplied num/den pairs, so no division
 * is needed.
 *
 * @param best_pitch receives {best lag, second-best lag}; defaults to {0, 1}
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    /* Slot 0 holds the best candidate, slot 1 the runner-up. */
    float num_best[2] = { -1, -1 };
    float den_best[2] = { 0, 0 };
    float energy = 1.f;

    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        energy += y[j] * y[j];

    for (int lag = 0; lag < max_pitch; lag++) {
        if (xcorr[lag] > 0) {
            /* Scale down before squaring so the square neither underflows
             * nor overflows to inf. */
            const float scaled = xcorr[lag] * 1e-12f;
            const float num = scaled * scaled;

            if ((num * den_best[1]) > (num_best[1] * energy)) {
                if ((num * den_best[0]) > (num_best[0] * energy)) {
                    /* New best: demote the current best to runner-up. */
                    num_best[1]   = num_best[0];
                    den_best[1]   = den_best[0];
                    best_pitch[1] = best_pitch[0];
                    num_best[0]   = num;
                    den_best[0]   = energy;
                    best_pitch[0] = lag;
                } else {
                    num_best[1]   = num;
                    den_best[1]   = energy;
                    best_pitch[1] = lag;
                }
            }
        }
        /* Slide the energy window by one sample. */
        energy += y[lag+len]*y[lag+len] - y[lag] * y[lag];
        energy = FFMAX(1, energy);
    }
}
933
pitch_search(const float * x_lp,float * y,int len,int max_pitch,int * pitch)934 static void pitch_search(const float *x_lp, float *y,
935 int len, int max_pitch, int *pitch)
936 {
937 int lag;
938 int best_pitch[2]={0,0};
939 int offset;
940
941 float x_lp4[WINDOW_SIZE];
942 float y_lp4[WINDOW_SIZE];
943 float xcorr[WINDOW_SIZE];
944
945 lag = len+max_pitch;
946
947 /* Downsample by 2 again */
948 for (int j = 0; j < len >> 2; j++)
949 x_lp4[j] = x_lp[2*j];
950 for (int j = 0; j < lag >> 2; j++)
951 y_lp4[j] = y[2*j];
952
953 /* Coarse search with 4x decimation */
954
955 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
956
957 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
958
959 /* Finer search with 2x decimation */
960 for (int i = 0; i < max_pitch >> 1; i++) {
961 float sum;
962 xcorr[i] = 0;
963 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
964 continue;
965 sum = celt_inner_prod(x_lp, y+i, len>>1);
966 xcorr[i] = FFMAX(-1, sum);
967 }
968
969 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
970
971 /* Refine by pseudo-interpolation */
972 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
973 float a, b, c;
974
975 a = xcorr[best_pitch[0] - 1];
976 b = xcorr[best_pitch[0]];
977 c = xcorr[best_pitch[0] + 1];
978 if (c - a > .7f * (b - a))
979 offset = 1;
980 else if (a - c > .7f * (b-c))
981 offset = -1;
982 else
983 offset = 0;
984 } else {
985 offset = 0;
986 }
987
988 *pitch = 2 * best_pitch[0] - offset;
989 }
990
/**
 * Transform NB_BANDS values with the basis rows in s->dct_table (presumably
 * DCT basis vectors; the table is built elsewhere) and scale by
 * sqrt(2 / NB_BANDS).
 *
 * scalarproduct_float() is called with FFALIGN(NB_BANDS, 4) elements, since
 * its length must be a multiple of 4, and it needs aligned operands. Callers
 * only provide NB_BANDS valid inputs, possibly in unaligned or exactly-sized
 * buffers (e.g. the NB_BANDS-entry Ly array). So the input is first copied
 * into a zero-padded, aligned buffer. This avoids reading past the end of
 * @p in, and garbage (or NaN) there can no longer leak into the sums.
 *
 * @param out receives NB_BANDS coefficients
 * @param in  NB_BANDS input values
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    LOCAL_ALIGNED_32(float, padded, [FFALIGN(NB_BANDS, 4)]);
    const float scale = sqrtf(2.f / NB_BANDS);

    memcpy(padded, in, NB_BANDS * sizeof(*padded));
    memset(padded + NB_BANDS, 0, (FFALIGN(NB_BANDS, 4) - NB_BANDS) * sizeof(*padded));

    for (int i = 0; i < NB_BANDS; i++) {
        const float sum = s->fdsp->scalarproduct_float(padded, s->dct_table[i],
                                                       FFALIGN(NB_BANDS, 4));

        out[i] = sum * scale;
    }
}
1000
/**
 * Analyse one FRAME_SIZE input frame and build the NB_FEATURES-entry RNN input.
 *
 * Feature layout written to @p features:
 * - [0, NB_DELTA_CEPS): sum of the current and two previous cepstra
 * - [NB_DELTA_CEPS, NB_BANDS): current cepstrum (DCT of floored log band
 *   energies, first two coefficients offset)
 * - [NB_BANDS, +NB_DELTA_CEPS): first-order cepstral difference ceps_0 - ceps_2
 * - [NB_BANDS+NB_DELTA_CEPS, +NB_DELTA_CEPS): second-order difference
 * - [NB_BANDS+2*NB_DELTA_CEPS, +NB_DELTA_CEPS): DCT of the normalized
 *   X/P band correlation
 * - [NB_BANDS+3*NB_DELTA_CEPS]: scaled pitch period
 * - [NB_BANDS+3*NB_DELTA_CEPS+1]: spectral variability
 *
 * The constant offsets (1.3, 0.9, 300, 12, 4, 2.1) are fixed biases,
 * presumably matching the normalization used when the model was trained.
 *
 * Side effects on @p st: frame_analysis() state, pitch_buf,
 * last_period/last_gain and, for non-silent frames, cepstral_mem and memid.
 *
 * @param X        receives the spectrum of the windowed input (FREQ_SIZE bins)
 * @param P        receives the spectrum of the pitch-delayed window (FREQ_SIZE bins)
 * @param Ex       receives NB_BANDS band energies of X
 * @param Ep       receives NB_BANDS band energies of P
 * @param Exp      receives NB_BANDS normalized X/P band correlations
 * @param features receives NB_FEATURES values
 * @param in       FRAME_SIZE new input samples
 * @return 1 when the total band energy is below 0.04 (features zeroed,
 *         cepstral history untouched), 0 otherwise
 */
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
{
    float E = 0;
    float *ceps_0, *ceps_1, *ceps_2;
    float spec_variability = 0;
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
    float pitch_buf[PITCH_BUF_SIZE>>1];
    int pitch_index;
    float gain;
    float *(pre[1]);
    float tmp[NB_BANDS];
    float follow, logMax;

    /* Spectrum and band energies of the current windowed frame. */
    frame_analysis(s, st, X, Ex, in);
    /* Shift the pitch history left by one frame and append the new input. */
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
    pre[0] = &st->pitch_buf[0];
    /* Pitch estimation runs on a 2x-decimated, whitened copy of the history. */
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
    pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
                 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
    pitch_index = PITCH_MAX_PERIOD-pitch_index;

    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
                           PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
    st->last_period = pitch_index;
    st->last_gain = gain;

    /* Window of history ending pitch_index samples before the newest sample. */
    for (int i = 0; i < WINDOW_SIZE; i++)
        p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];

    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
    forward_transform(st, P, p);
    compute_band_energy(Ep, P);
    compute_band_corr(Exp, X, P);

    /* Normalize the correlation per band; .001f avoids division by zero. */
    for (int i = 0; i < NB_BANDS; i++)
        Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);

    dct(s, tmp, Exp);

    for (int i = 0; i < NB_DELTA_CEPS; i++)
        features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];

    features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
    features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
    features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
    logMax = -2;
    follow = -2;

    /* Log band energies, floored at 7 below the running maximum and at 1.5
     * below a decaying envelope of the lower bands. */
    for (int i = 0; i < NB_BANDS; i++) {
        Ly[i] = log10f(1e-2f + Ex[i]);
        Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
        logMax = FFMAX(logMax, Ly[i]);
        follow = FFMAX(follow-1.5, Ly[i]);
        E += Ex[i];
    }

    if (E < 0.04f) {
        /* If there's no audio, avoid messing up the state. */
        RNN_CLEAR(features, NB_FEATURES);
        return 1;
    }

    /* Cepstrum of the current frame, stored in the ring buffer. */
    dct(s, features, Ly);
    features[0] -= 12;
    features[1] -= 4;
    ceps_0 = st->cepstral_mem[st->memid];
    ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
    ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];

    for (int i = 0; i < NB_BANDS; i++)
        ceps_0[i] = features[i];

    st->memid++;
    /* Smoothed cepstrum plus first- and second-order differences. */
    for (int i = 0; i < NB_DELTA_CEPS; i++) {
        features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
        features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
        features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
    }
    /* Wrap the ring-buffer index. */
    if (st->memid == CEPS_MEM)
        st->memid = 0;

    /* Spectral variability: mean over stored cepstra of the squared distance
     * to the nearest other stored cepstrum. */
    for (int i = 0; i < CEPS_MEM; i++) {
        float mindist = 1e15f;
        for (int j = 0; j < CEPS_MEM; j++) {
            float dist = 0.f;
            for (int k = 0; k < NB_BANDS; k++) {
                float tmp;

                tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
                dist += tmp*tmp;
            }

            if (j != i)
                mindist = FFMIN(mindist, dist);
        }

        spec_variability += mindist;
    }

    features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;

    return 0;
}
1108
interp_band_gain(float * g,const float * bandE)1109 static void interp_band_gain(float *g, const float *bandE)
1110 {
1111 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1112
1113 for (int i = 0; i < NB_BANDS - 1; i++) {
1114 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1115
1116 for (int j = 0; j < band_size; j++) {
1117 float frac = (float)j / band_size;
1118
1119 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1120 }
1121 }
1122 }
1123
/**
 * Add a per-band amount of the pitch-delayed spectrum P to X, then rescale
 * X so that its band energies move back toward the original Ex.
 *
 * The per-band mix factor is 1 when the pitch correlation Exp exceeds g.
 * Otherwise it is derived from Exp and g, clipped to [0, 1] before taking
 * the square root, and then multiplied by sqrt(Ex/Ep).
 *
 * @param X   spectrum, modified in place (FREQ_SIZE bins)
 * @param P   pitch-delayed spectrum (FREQ_SIZE bins)
 * @param Ex  band energies of X (NB_BANDS)
 * @param Ep  band energies of P (NB_BANDS)
 * @param Exp normalized X/P band correlation (NB_BANDS)
 * @param g   per-band gains (NB_BANDS); presumably the RNN output
 */
static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
                         const float *Exp, const float *g)
{
    float r[NB_BANDS];
    float norm[NB_BANDS];
    float newE[NB_BANDS];
    float rf[FREQ_SIZE] = {0};
    float normf[FREQ_SIZE] = {0};

    for (int b = 0; b < NB_BANDS; b++) {
        float ratio;

        if (Exp[b] > g[b])
            ratio = 1;
        else
            ratio = SQUARE(Exp[b])*(1-SQUARE(g[b]))/(.001 + SQUARE(g[b])*(1-SQUARE(Exp[b])));
        ratio = sqrtf(av_clipf(ratio, 0, 1));
        ratio *= sqrtf(Ex[b]/(1e-8+Ep[b]));
        r[b] = ratio;
    }

    /* Mix in the pitch spectrum. */
    interp_band_gain(rf, r);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re += rf[k]*P[k].re;
        X[k].im += rf[k]*P[k].im;
    }

    /* Renormalize toward the original band energies. */
    compute_band_energy(newE, X);
    for (int b = 0; b < NB_BANDS; b++)
        norm[b] = sqrtf(Ex[b] / (1e-8+newE[b]));
    interp_band_gain(normf, norm);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re *= normf[k];
        X[k].im *= normf[k];
    }
}
1154
/*
 * Lookup table for tansig_approx(): entry i holds tanh(0.04 * i) rounded to six
 * decimals, for i = 0..200, i.e. x from 0 to 8 in steps of 0.04.
 * tansig_approx() selects the entry nearest to |x| (index floor(0.5 + 25 * |x|))
 * and refines it with a first-order correction.
 */
static const float tansig_table[201] = {
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f,
};
1198
tansig_approx(float x)1199 static inline float tansig_approx(float x)
1200 {
1201 float y, dy;
1202 float sign=1;
1203 int i;
1204
1205 /* Tests are reversed to catch NaNs */
1206 if (!(x<8))
1207 return 1;
1208 if (!(x>-8))
1209 return -1;
1210 /* Another check in case of -ffast-math */
1211
1212 if (isnan(x))
1213 return 0;
1214
1215 if (x < 0) {
1216 x=-x;
1217 sign=-1;
1218 }
1219 i = (int)floor(.5f+25*x);
1220 x -= .04f*i;
1221 y = tansig_table[i];
1222 dy = 1-y*y;
1223 y = y + x*dy*(1 - y*x);
1224 return sign*y;
1225 }
1226
sigmoid_approx(float x)1227 static inline float sigmoid_approx(float x)
1228 {
1229 return .5f + .5f*tansig_approx(.5f*x);
1230 }
1231
compute_dense(const DenseLayer * layer,float * output,const float * input)1232 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1233 {
1234 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1235
1236 for (int i = 0; i < N; i++) {
1237 /* Compute update gate. */
1238 float sum = layer->bias[i];
1239
1240 for (int j = 0; j < M; j++)
1241 sum += layer->input_weights[j * stride + i] * input[j];
1242
1243 output[i] = WEIGHTS_SCALE * sum;
1244 }
1245
1246 if (layer->activation == ACTIVATION_SIGMOID) {
1247 for (int i = 0; i < N; i++)
1248 output[i] = sigmoid_approx(output[i]);
1249 } else if (layer->activation == ACTIVATION_TANH) {
1250 for (int i = 0; i < N; i++)
1251 output[i] = tansig_approx(output[i]);
1252 } else if (layer->activation == ACTIVATION_RELU) {
1253 for (int i = 0; i < N; i++)
1254 output[i] = FFMAX(0, output[i]);
1255 } else {
1256 av_assert0(0);
1257 }
1258 }
1259
compute_gru(AudioRNNContext * s,const GRULayer * gru,float * state,const float * input)1260 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1261 {
1262 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1263 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1264 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1265 const int M = gru->nb_inputs;
1266 const int N = gru->nb_neurons;
1267 const int AN = FFALIGN(N, 4);
1268 const int AM = FFALIGN(M, 4);
1269 const int stride = 3 * AN, istride = 3 * AM;
1270
1271 for (int i = 0; i < N; i++) {
1272 /* Compute update gate. */
1273 float sum = gru->bias[i];
1274
1275 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1276 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1277 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1278 }
1279
1280 for (int i = 0; i < N; i++) {
1281 /* Compute reset gate. */
1282 float sum = gru->bias[N + i];
1283
1284 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1285 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1286 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1287 }
1288
1289 for (int i = 0; i < N; i++) {
1290 /* Compute output. */
1291 float sum = gru->bias[2 * N + i];
1292
1293 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1294 for (int j = 0; j < N; j++)
1295 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1296
1297 if (gru->activation == ACTIVATION_SIGMOID)
1298 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1299 else if (gru->activation == ACTIVATION_TANH)
1300 sum = tansig_approx(WEIGHTS_SCALE * sum);
1301 else if (gru->activation == ACTIVATION_RELU)
1302 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1303 else
1304 av_assert0(0);
1305 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1306 }
1307
1308 RNN_COPY(state, h, N);
1309 }
1310
1311 #define INPUT_SIZE 42
1312
compute_rnn(AudioRNNContext * s,RNNState * rnn,float * gains,float * vad,const float * input)1313 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1314 {
1315 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1316 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1317 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1318
1319 compute_dense(rnn->model->input_dense, dense_out, input);
1320 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1321 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1322
1323 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1324 memcpy(noise_input + rnn->model->input_dense_size,
1325 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1326 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1327 input, INPUT_SIZE * sizeof(float));
1328
1329 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1330
1331 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1332 memcpy(denoise_input + rnn->model->vad_gru_size,
1333 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1334 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1335 input, INPUT_SIZE * sizeof(float));
1336
1337 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1338 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1339 }
1340
rnnoise_channel(AudioRNNContext * s,DenoiseState * st,float * out,const float * in)1341 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1342 {
1343 AVComplexFloat X[FREQ_SIZE];
1344 AVComplexFloat P[WINDOW_SIZE];
1345 float x[FRAME_SIZE];
1346 float Ex[NB_BANDS], Ep[NB_BANDS];
1347 LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1348 float features[NB_FEATURES];
1349 float g[NB_BANDS];
1350 float gf[FREQ_SIZE];
1351 float vad_prob = 0;
1352 static const float a_hp[2] = {-1.99599, 0.99600};
1353 static const float b_hp[2] = {-2, 1};
1354 int silence;
1355
1356 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1357 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1358
1359 if (!silence) {
1360 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1361 pitch_filter(X, P, Ex, Ep, Exp, g);
1362 for (int i = 0; i < NB_BANDS; i++) {
1363 float alpha = .6f;
1364
1365 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1366 st->lastg[i] = g[i];
1367 }
1368
1369 interp_band_gain(gf, g);
1370
1371 for (int i = 0; i < FREQ_SIZE; i++) {
1372 X[i].re *= gf[i];
1373 X[i].im *= gf[i];
1374 }
1375 }
1376
1377 frame_synthesis(s, st, out, X);
1378
1379 return vad_prob;
1380 }
1381
1382 typedef struct ThreadData {
1383 AVFrame *in, *out;
1384 } ThreadData;
1385
rnnoise_channels(AVFilterContext * ctx,void * arg,int jobnr,int nb_jobs)1386 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1387 {
1388 AudioRNNContext *s = ctx->priv;
1389 ThreadData *td = arg;
1390 AVFrame *in = td->in;
1391 AVFrame *out = td->out;
1392 const int start = (out->channels * jobnr) / nb_jobs;
1393 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1394
1395 for (int ch = start; ch < end; ch++) {
1396 rnnoise_channel(s, &s->st[ch],
1397 (float *)out->extended_data[ch],
1398 (const float *)in->extended_data[ch]);
1399 }
1400
1401 return 0;
1402 }
1403
filter_frame(AVFilterLink * inlink,AVFrame * in)1404 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1405 {
1406 AVFilterContext *ctx = inlink->dst;
1407 AVFilterLink *outlink = ctx->outputs[0];
1408 AVFrame *out = NULL;
1409 ThreadData td;
1410
1411 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1412 if (!out) {
1413 av_frame_free(&in);
1414 return AVERROR(ENOMEM);
1415 }
1416 out->pts = in->pts;
1417
1418 td.in = in; td.out = out;
1419 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1420 ff_filter_get_nb_threads(ctx)));
1421
1422 av_frame_free(&in);
1423 return ff_filter_frame(outlink, out);
1424 }
1425
activate(AVFilterContext * ctx)1426 static int activate(AVFilterContext *ctx)
1427 {
1428 AVFilterLink *inlink = ctx->inputs[0];
1429 AVFilterLink *outlink = ctx->outputs[0];
1430 AVFrame *in = NULL;
1431 int ret;
1432
1433 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1434
1435 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1436 if (ret < 0)
1437 return ret;
1438
1439 if (ret > 0)
1440 return filter_frame(inlink, in);
1441
1442 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1443 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1444
1445 return FFERROR_NOT_READY;
1446 }
1447
init(AVFilterContext * ctx)1448 static av_cold int init(AVFilterContext *ctx)
1449 {
1450 AudioRNNContext *s = ctx->priv;
1451 FILE *f;
1452
1453 s->fdsp = avpriv_float_dsp_alloc(0);
1454 if (!s->fdsp)
1455 return AVERROR(ENOMEM);
1456
1457 if (!s->model_name)
1458 return AVERROR(EINVAL);
1459 f = av_fopen_utf8(s->model_name, "r");
1460 if (!f)
1461 return AVERROR(EINVAL);
1462
1463 s->model = rnnoise_model_from_file(f);
1464 fclose(f);
1465 if (!s->model)
1466 return AVERROR(EINVAL);
1467
1468 for (int i = 0; i < FRAME_SIZE; i++) {
1469 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1470 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1471 }
1472
1473 for (int i = 0; i < NB_BANDS; i++) {
1474 for (int j = 0; j < NB_BANDS; j++) {
1475 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1476 if (j == 0)
1477 s->dct_table[j][i] *= sqrtf(.5);
1478 }
1479 }
1480
1481 return 0;
1482 }
1483
uninit(AVFilterContext * ctx)1484 static av_cold void uninit(AVFilterContext *ctx)
1485 {
1486 AudioRNNContext *s = ctx->priv;
1487
1488 av_freep(&s->fdsp);
1489 rnnoise_model_free(s->model);
1490 s->model = NULL;
1491
1492 if (s->st) {
1493 for (int ch = 0; ch < s->channels; ch++) {
1494 av_freep(&s->st[ch].rnn.vad_gru_state);
1495 av_freep(&s->st[ch].rnn.noise_gru_state);
1496 av_freep(&s->st[ch].rnn.denoise_gru_state);
1497 av_tx_uninit(&s->st[ch].tx);
1498 av_tx_uninit(&s->st[ch].txi);
1499 }
1500 }
1501 av_freep(&s->st);
1502 }
1503
1504 static const AVFilterPad inputs[] = {
1505 {
1506 .name = "default",
1507 .type = AVMEDIA_TYPE_AUDIO,
1508 .config_props = config_input,
1509 },
1510 { NULL }
1511 };
1512
1513 static const AVFilterPad outputs[] = {
1514 {
1515 .name = "default",
1516 .type = AVMEDIA_TYPE_AUDIO,
1517 },
1518 { NULL }
1519 };
1520
1521 #define OFFSET(x) offsetof(AudioRNNContext, x)
1522 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1523
1524 static const AVOption arnndn_options[] = {
1525 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1526 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1527 { NULL }
1528 };
1529
1530 AVFILTER_DEFINE_CLASS(arnndn);
1531
1532 AVFilter ff_af_arnndn = {
1533 .name = "arnndn",
1534 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1535 .query_formats = query_formats,
1536 .priv_size = sizeof(AudioRNNContext),
1537 .priv_class = &arnndn_class,
1538 .activate = activate,
1539 .init = init,
1540 .uninit = uninit,
1541 .inputs = inputs,
1542 .outputs = outputs,
1543 .flags = AVFILTER_FLAG_SLICE_THREADS,
1544 };
1545