1 /* Copyright (c) 2018 Gregor Richards
2  * Copyright (c) 2017 Mozilla */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7 
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10 
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14 
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27 
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <limits.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include "kiss_fft.h"
#include "common.h"
#include <math.h>
#include "rnnoise-nu.h"
#include "pitch.h"
#include "arch.h"
#include "rnn.h"
#include "rnn_data.h"
43 
/* Sample rate the band layout (eband5ms) is expressed for; also the value
 * check_init() substitutes when no positive sample rate has been set. */
#define DEFAULT_SAMPLE_RATE 48000

/* A frame is 120<<2 = 480 samples (10 ms at 48 kHz). Each analysis window
 * covers two frames (previous + current), and a real FFT of WINDOW_SIZE
 * points has FRAME_SIZE+1 non-redundant bins. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
#define WINDOW_SIZE (2*FRAME_SIZE)
#define FREQ_SIZE (FRAME_SIZE + 1)

/* Pitch search limits and history length, in samples. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* Evaluates its argument twice: only pass side-effect-free expressions. */
#define SQUARE(x) ((x)*(x))

/* 1: overlapping triangular bands (one band per boundary).
 * 0: rectangular bands between consecutive boundaries. */
#define SMOOTH_BANDS 1

/* Number of entries in eband5ms[] and DenoiseState.band_bins[]. */
#define NB_BAND_BOUNDARIES 22

#if SMOOTH_BANDS
#define NB_BANDS 22
#else
#define NB_BANDS 21
#endif

/* Depth of the cepstral history ring buffer, and how many low-order
 * cepstral / pitch-correlation coefficients get delta features. */
#define CEPS_MEM 8
#define NB_DELTA_CEPS 6

/* NB_BANDS cepstral values, 3*NB_DELTA_CEPS (first difference, second
 * difference, pitch correlation), plus pitch period and spectral
 * variability. */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)

/* Smallest accepted max_attenuation (rnnoise_set_param() falls back to it for
 * out-of-range values) and the floor band gains are clamped to before the
 * max-attenuation rescale in rnnoise_process_frame().
 * NOTE(review): this was described as a 60 dB limit. 1e-6 is -60 dB as a
 * power ratio, but the gains multiply bin amplitudes (X[i].r *= gf[i]), where
 * it is -120 dB. Confirm the intended figure. */
#define MIN_MAX_ATTENUATION 0.000001f


/* Build with -DTRAINING=1 to compile the training-data generator main(). */
#ifndef TRAINING
#define TRAINING 0
#endif
80 
81 
/* cb is the default model: rnnoise_init() uses it when called with a NULL
 * model. It is defined in another translation unit. */
extern const struct RNNModel model_cb;
clip_end(bool & v)84 
85 
/* Band boundaries in 200 Hz units (see the frequency header row), 22 entries
 * (NB_BAND_BOUNDARIES). check_init() scales each entry by 1<<FRAME_SIZE_SHIFT
 * and by DEFAULT_SAMPLE_RATE/sample_rate to obtain per-state FFT bin indices. */
static const opus_int16 eband5ms[] = {
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
  0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
set_clip_start(float pos)90 
91 
/* Complete per-stream denoiser state. rnnoise_init() zeroes it. Fields marked
 * "lazy" are filled in by check_init() the first time a transform is needed. */
struct DenoiseState {
  /* Nonzero once check_init() has run. */
  int init;
  /* Lazy: FFT setup for 2*FRAME_SIZE points; released in rnnoise_destroy(). */
  kiss_fft_state *kfft;
  /* Lazy: rising half of the analysis/synthesis window (apply_window()). */
  float half_window[FRAME_SIZE];
  /* Lazy: NB_BANDS x NB_BANDS cosine basis used by dct(). */
  float dct_table[NB_BANDS*NB_BANDS];

  /* Input sample rate in Hz; <= 0 is replaced by DEFAULT_SAMPLE_RATE in
   * check_init(). */
  int sample_rate;

  /* Previous input frame: the first half of the next analysis window. */
  float analysis_mem[FRAME_SIZE];
  /* Ring buffer of the last CEPS_MEM cepstra; memid is the next write slot. */
  float cepstral_mem[CEPS_MEM][NB_BANDS];
  int memid;
  /* Second half of the last synthesized window, kept for overlap-add. */
  float synthesis_mem[FRAME_SIZE];
  /* Recent input history used by the pitch analysis. */
  float pitch_buf[PITCH_BUF_SIZE];
  /* NOTE(review): not referenced anywhere in this file. */
  float pitch_enh_buf[PITCH_BUF_SIZE];

  /* Bands adjusted for the sample rate */
  opus_int16 band_bins[NB_BAND_BOUNDARIES];

  /* Pitch gain and period found for the previous frame, fed back into
   * remove_doubling(). */
  float last_gain;
  int last_period;
  /* State of the input high-pass biquad (zeros at DC). */
  float mem_hp_x[2];
  /* Previous frame's band gains; new gains may not drop below 0.6x these. */
  float lastg[NB_BANDS];
  /* Model pointer and GRU state buffers (allocated in rnnoise_init()). */
  RNNState rnn;

  /* 0 disables the limit. Otherwise gains are rescaled so the smallest band
   * gain becomes this value. */
  float max_attenuation;
};
118 
119 #if SMOOTH_BANDS
120 void compute_band_energy(DenoiseState *st, float *bandE, const kiss_fft_cpx *X) {
121   int i;
122   float sum[NB_BANDS] = {0};
123   for (i=0;i<NB_BANDS-1;i++)
124   {
125     int j;
126     int band_size;
127     band_size = st->band_bins[i+1] - st->band_bins[i];
128     for (j=0;j<band_size;j++) {
129       float tmp;
130       float frac = (float)j/band_size;
131       tmp = SQUARE(X[st->band_bins[i] + j].r);
132       tmp += SQUARE(X[st->band_bins[i] + j].i);
133       sum[i] += (1-frac)*tmp;
134       sum[i+1] += frac*tmp;
135     }
136   }
137   sum[0] *= 2;
138   sum[NB_BANDS-1] *= 2;
139   for (i=0;i<NB_BANDS;i++)
140   {
141     bandE[i] = sum[i];
142   }
143 }
144 
145 void compute_band_corr(DenoiseState *st, float *bandE, const kiss_fft_cpx *X, const kiss_fft_cpx *P) {
146   int i;
147   float sum[NB_BANDS] = {0};
148   for (i=0;i<NB_BANDS-1;i++)
149   {
150     int j;
151     int band_size;
152     band_size = st->band_bins[i+1] - st->band_bins[i];
153     for (j=0;j<band_size;j++) {
154       float tmp;
155       float frac = (float)j/band_size;
156       tmp = X[st->band_bins[i] + j].r * P[st->band_bins[i] + j].r;
157       tmp += X[st->band_bins[i] + j].i * P[st->band_bins[i] + j].i;
158       sum[i] += (1-frac)*tmp;
159       sum[i+1] += frac*tmp;
160     }
161   }
162   sum[0] *= 2;
163   sum[NB_BANDS-1] *= 2;
164   for (i=0;i<NB_BANDS;i++)
165   {
166     bandE[i] = sum[i];
167   }
168 }
169 
170 void interp_band_gain(DenoiseState *st, float *g, const float *bandE) {
171   int i;
172   float prev, cur, next;
173   memset(g, 0, FREQ_SIZE);
174   prev = cur = next = bandE[0]/2;
175   for (i=0;i<NB_BANDS-1;i++)
176   {
177     int j;
178     int band_size;
179 
180     /* Adjust our inputs to the surrounding bands */
181     prev = cur;
182     cur = next;
183     next = bandE[i+1]/2;
184 
185     band_size = st->band_bins[i+1] - st->band_bins[i];
186     for (j=0;j<band_size;j++) {
187       float frac = (float)j/band_size;
188       g[st->band_bins[i] + j] = (1-frac)*prev + frac*next + cur;
189     }
190   }
191 }
192 #else
193 void compute_band_energy(DenoiseState *st, float *bandE, const kiss_fft_cpx *X) {
194   int i;
195   for (i=0;i<NB_BANDS;i++)
196   {
197     int j;
198     opus_val32 sum = 0;
199     for (j=0;j<(st->band_bins[i+1] - st->band_bins[i]);j++) {
200       sum += SQUARE(X[st->band_bins[i] + j].r);
201       sum += SQUARE(X[st->band_bins[i] + j].i);
202     }
203     bandE[i] = sum;
204   }
205 }
206 
207 void interp_band_gain(DenoiseState *st, float *g, const float *bandE) {
208   int i;
209   memset(g, 0, FREQ_SIZE);
210   for (i=0;i<NB_BANDS;i++)
211   {
212     int j;
213     for (j=0;j<(st->band_bins[i+1] - st->band_bins[i]);j++)
214       g[st->band_bins[i] + j] = bandE[i];
215   }
216 }
217 #endif
218 
219 
220 static void check_init(DenoiseState *st) {
221   int i;
222   if (st->init) return;
223   /* FIXME: Deallocate this! */
224   st->kfft = opus_fft_alloc_twiddles(2*FRAME_SIZE, NULL, NULL, NULL, 0);
225 
226   /* Get the sample rate set up */
227   if (st->sample_rate <= 0) st->sample_rate = DEFAULT_SAMPLE_RATE;
228 
229   /* Adjust the bins for the sample rate */
230   for (i = 0; i < NB_BAND_BOUNDARIES; i++)
231     st->band_bins[i] = (((long) eband5ms[i]) << FRAME_SIZE_SHIFT) * DEFAULT_SAMPLE_RATE / st->sample_rate;
232 
233   /* Make sure nothing's above the Nyquist frequency */
234   for (i = 0; i < NB_BANDS; i++)
235     if (st->band_bins[i] >= FRAME_SIZE) st->band_bins[i] = FRAME_SIZE - 1;
236 
237   for (i=0;i<FRAME_SIZE;i++)
238     st->half_window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
239   for (i=0;i<NB_BANDS;i++) {
240     int j;
241     for (j=0;j<NB_BANDS;j++) {
242       st->dct_table[i*NB_BANDS + j] = cos((i+.5)*j*M_PI/NB_BANDS);
243       if (j==0) st->dct_table[i*NB_BANDS + j] *= sqrt(.5);
244     }
245   }
246   st->init = 1;
247 }
248 
249 static void dct(DenoiseState *st, float *out, const float *in) {
250   int i;
251   check_init(st);
252   for (i=0;i<NB_BANDS;i++) {
253     int j;
254     float sum = 0;
255     for (j=0;j<NB_BANDS;j++) {
256       sum += in[j] * st->dct_table[j*NB_BANDS + i];
257     }
258     out[i] = sum*sqrt(2./22);
259   }
260 }
261 
262 #if 0
263 static void idct(DenoiseState *st, float *out, const float *in) {
264   int i;
265   check_init(st);
266   for (i=0;i<NB_BANDS;i++) {
267     int j;
268     float sum = 0;
269     for (j=0;j<NB_BANDS;j++) {
270       sum += in[j] * st->dct_table[i*NB_BANDS + j];
271     }
272     out[i] = sum*sqrt(2./22);
273   }
274 }
275 #endif
276 
277 static void forward_transform(DenoiseState *st, kiss_fft_cpx *out, const float *in) {
278   int i;
279   kiss_fft_cpx x[WINDOW_SIZE];
280   kiss_fft_cpx y[WINDOW_SIZE];
281   check_init(st);
282   for (i=0;i<WINDOW_SIZE;i++) {
283     x[i].r = in[i];
284     x[i].i = 0;
285   }
286   opus_fft(st->kfft, x, y, 0);
287   for (i=0;i<FREQ_SIZE;i++) {
288     out[i] = y[i];
289   }
290 }
291 
292 static void inverse_transform(DenoiseState *st, float *out, const kiss_fft_cpx *in) {
293   int i;
294   kiss_fft_cpx x[WINDOW_SIZE];
295   kiss_fft_cpx y[WINDOW_SIZE];
296   check_init(st);
297   for (i=0;i<FREQ_SIZE;i++) {
298     x[i] = in[i];
299   }
300   for (;i<WINDOW_SIZE;i++) {
301     x[i].r = x[WINDOW_SIZE - i].r;
302     x[i].i = -x[WINDOW_SIZE - i].i;
303   }
304   opus_fft(st->kfft, x, y, 0);
305   /* output in reverse order for IFFT. */
306   out[0] = WINDOW_SIZE*y[0].r;
307   for (i=1;i<WINDOW_SIZE;i++) {
308     out[i] = WINDOW_SIZE*y[WINDOW_SIZE - i].r;
309   }
310 }
311 
312 static void apply_window(DenoiseState *st, float *x) {
313   int i;
314   check_init(st);
315   for (i=0;i<FRAME_SIZE;i++) {
316     x[i] *= st->half_window[i];
317     x[WINDOW_SIZE - 1 - i] *= st->half_window[i];
318   }
319 }
320 
321 int rnnoise_get_size() {
322   return sizeof(DenoiseState);
323 }
324 
325 int rnnoise_init(DenoiseState *st, RNNModel *model) {
326   memset(st, 0, sizeof(*st));
327   if (model)
328     st->rnn.model = model;
329   else
330     st->rnn.model = &model_cb;
331   st->rnn.vad_gru_state = calloc(sizeof(float), st->rnn.model->vad_gru_size);
332   st->rnn.noise_gru_state = calloc(sizeof(float), st->rnn.model->noise_gru_size);
333   st->rnn.denoise_gru_state = calloc(sizeof(float), st->rnn.model->denoise_gru_size);
334   return 0;
335 }
336 
337 DenoiseState *rnnoise_create(RNNModel *model) {
338   DenoiseState *st;
339   st = malloc(rnnoise_get_size());
340   rnnoise_init(st, model);
341   return st;
342 }
343 
344 void rnnoise_destroy(DenoiseState *st) {
345   if (st->init)
346     free(st->kfft);
347   free(st->rnn.vad_gru_state);
348   free(st->rnn.noise_gru_state);
349   free(st->rnn.denoise_gru_state);
350   free(st);
351 }
352 
#if TRAINING
/* Training-only augmentation state, re-randomised by main():
 * - lowpass: first FFT bin that frame_analysis() zeroes (FREQ_SIZE means none).
 * - band_lp: highest band whose target gain is kept; bands above it are
 *   written as -1 (NB_BANDS means keep all). */
int lowpass = FREQ_SIZE;
int band_lp = NB_BANDS;
#endif
357 
358 static void frame_analysis(DenoiseState *st, kiss_fft_cpx *X, float *Ex, const float *in) {
359   int i;
360   float x[WINDOW_SIZE];
361   RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
362   for (i=0;i<FRAME_SIZE;i++) x[FRAME_SIZE + i] = in[i];
363   RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
364   apply_window(st, x);
365   forward_transform(st, X, x);
366 #if TRAINING
367   for (i=lowpass;i<FREQ_SIZE;i++)
368     X[i].r = X[i].i = 0;
369 #endif
370   compute_band_energy(st, Ex, X);
371 }
372 
/* Analyse one FRAME_SIZE block of input and build the RNN feature vector.
 *
 * Outputs:
 *   X, Ex    spectrum of the current analysis window and its band energies.
 *   P, Ep    spectrum and band energies of the window delayed by the detected
 *            pitch period.
 *   Exp      per-band X/P correlation normalised by sqrt(Ex*Ep).
 *   features NB_FEATURES values laid out as:
 *     [0, NB_DELTA_CEPS)                         sum of the last three cepstra
 *     [NB_DELTA_CEPS, NB_BANDS)                  current cepstrum
 *     [NB_BANDS, +NB_DELTA_CEPS)                 first difference (c0 - c2)
 *     [NB_BANDS+NB_DELTA_CEPS, +NB_DELTA_CEPS)   second difference
 *     [NB_BANDS+2*NB_DELTA_CEPS, +NB_DELTA_CEPS) DCT of Exp (pitch correlation)
 *     [NB_BANDS+3*NB_DELTA_CEPS]                 pitch period, offset/scaled
 *     [NB_BANDS+3*NB_DELTA_CEPS+1]               spectral variability
 *
 * Returns nonzero for a frame to be treated as silent. Outside training this
 * means total band energy below 0.04: the features are zeroed and the
 * cepstral history is left untouched. In training builds it means energy
 * below 0.1, with all features still computed. The analysis memory, pitch
 * history and last_period/last_gain are updated in every case. */
static int compute_frame_features(DenoiseState *st, kiss_fft_cpx *X, kiss_fft_cpx *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in) {
  int i;
  float E = 0;
  float *ceps_0, *ceps_1, *ceps_2;
  float spec_variability = 0;
  float Ly[NB_BANDS];
  float p[WINDOW_SIZE];
  float pitch_buf[PITCH_BUF_SIZE>>1];
  int pitch_index;
  float gain;
  float *(pre[1]);
  float tmp[NB_BANDS];
  float follow, logMax;
  frame_analysis(st, X, Ex, in);
  /* Slide the pitch history left by one frame and append the new input. */
  RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
  RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
  pre[0] = &st->pitch_buf[0];
  /* Pitch search runs on a half-length copy of the history (presumably a 2x
   * decimation, given the PITCH_BUF_SIZE>>1 buffer; see pitch.h). */
  pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
  pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
               PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
  /* Turn pitch_search's result into the delay used to index pitch_buf below. */
  pitch_index = PITCH_MAX_PERIOD-pitch_index;

  /* Refine the period and obtain a pitch gain; both are remembered for the
   * next frame's call. */
  gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
          PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
  st->last_period = pitch_index;
  st->last_gain = gain;
  /* Window of history delayed by pitch_index samples relative to the current
   * analysis window (the last WINDOW_SIZE samples of pitch_buf). */
  for (i=0;i<WINDOW_SIZE;i++)
    p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
  apply_window(st, p);
  forward_transform(st, P, p);
  compute_band_energy(st, Ep, P);
  /* NOTE(review): with SMOOTH_BANDS == 0 nothing writes Exp before the
   * normalisation below reads it. */
#if SMOOTH_BANDS
  compute_band_corr(st, Exp, X, P);
#endif
  /* Normalised correlation; the .001 keeps silent bands finite. */
  for (i=0;i<NB_BANDS;i++) Exp[i] = Exp[i]/sqrt(.001+Ex[i]*Ep[i]);
  dct(st, tmp, Exp);
  for (i=0;i<NB_DELTA_CEPS;i++) features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
  /* Fixed offsets on the first two pitch-correlation coefficients and the
   * pitch period (presumably centring constants for the model's inputs). */
  features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
  features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
  features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
  logMax = -2;
  follow = -2;
  /* Log band energies, floored to within 7 (log10 units) of the loudest band
   * so far and within 1.5 of a follower decaying by 1.5 per band. Also
   * accumulate the total energy used for the silence test. */
  for (i=0;i<NB_BANDS;i++) {
    Ly[i] = log10(1e-2+Ex[i]);
    Ly[i] = MAX16(logMax-7, MAX16(follow-1.5, Ly[i]));
    logMax = MAX16(logMax, Ly[i]);
    follow = MAX16(follow-1.5, Ly[i]);
    E += Ex[i];
  }
  if (!TRAINING && E < 0.04) {
    /* If there's no audio, avoid messing up the state. */
    RNN_CLEAR(features, NB_FEATURES);
    return 1;
  }
  /* Cepstrum of the log band energies, with fixed offsets on c0 and c1. */
  dct(st, features, Ly);
  features[0] -= 12;
  features[1] -= 4;
  /* ceps_0 is this frame's slot; ceps_1/ceps_2 are the two previous slots
   * (wrapping around the ring buffer). */
  ceps_0 = st->cepstral_mem[st->memid];
  ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
  ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
  for (i=0;i<NB_BANDS;i++) ceps_0[i] = features[i];
  st->memid++;
  /* Temporal sum, first difference and second difference of the low-order
   * cepstral coefficients. */
  for (i=0;i<NB_DELTA_CEPS;i++) {
    features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
    features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
    features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
  }
  /* Spectral variability features. */
  /* (Wrap the ring-buffer write index first.) For every stored cepstrum, find
   * the squared distance to its nearest other stored cepstrum, then average. */
  if (st->memid == CEPS_MEM) st->memid = 0;
  for (i=0;i<CEPS_MEM;i++)
  {
    int j;
    float mindist = 1e15f;
    for (j=0;j<CEPS_MEM;j++)
    {
      int k;
      float dist=0;
      for (k=0;k<NB_BANDS;k++)
      {
        float tmp;
        tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
        dist += tmp*tmp;
      }
      if (j!=i)
        mindist = MIN32(mindist, dist);
    }
    spec_variability += mindist;
  }
  features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
  return TRAINING && E < 0.1;
}
465 
466 static void frame_synthesis(DenoiseState *st, float *out, const kiss_fft_cpx *y) {
467   float x[WINDOW_SIZE];
468   int i;
469   inverse_transform(st, x, y);
470   apply_window(st, x);
471   for (i=0;i<FRAME_SIZE;i++) out[i] = x[i] + st->synthesis_mem[i];
472   RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
473 }
474 
/* Second-order IIR section in transposed direct form II, run over N samples:
 *   y[n] = x[n] + b[0]*x[n-1] + b[1]*x[n-2] - a[0]*y[n-1] - a[1]*y[n-2]
 * The leading numerator and denominator coefficients are implicitly 1, so b
 * and a hold only the two trailing taps each. mem carries the state between
 * calls, and its updates are evaluated in double. y may alias x, since each
 * input sample is read before its output is stored. */
static void biquad(float *y, float mem[2], const float *x, const float *b, const float *a, int N) {
  int remaining;
  for (remaining = N; remaining > 0; remaining--) {
    const float in_sample = *x++;
    const float out_sample = in_sample + mem[0];
    mem[0] = mem[1] + (b[0]*(double)in_sample - a[0]*(double)out_sample);
    mem[1] = (b[1]*(double)in_sample - a[1]*(double)out_sample);
    *y++ = out_sample;
  }
}
486 
487 void pitch_filter(DenoiseState *st, kiss_fft_cpx *X, const kiss_fft_cpx *P, const float *Ex, const float *Ep,
488                   const float *Exp, const float *g) {
489   int i;
490   float r[NB_BANDS];
491   float rf[FREQ_SIZE] = {0};
492   for (i=0;i<NB_BANDS;i++) {
493 #if 0
494     if (Exp[i]>g[i]) r[i] = 1;
495     else r[i] = Exp[i]*(1-g[i])/(.001 + g[i]*(1-Exp[i]));
496     r[i] = MIN16(1, MAX16(0, r[i]));
497 #else
498     if (Exp[i]>g[i]) r[i] = 1;
499     else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
500     r[i] = sqrt(MIN16(1, MAX16(0, r[i])));
501 #endif
502     r[i] *= sqrt(Ex[i]/(1e-8+Ep[i]));
503   }
504   interp_band_gain(st, rf, r);
505   for (i=0;i<FREQ_SIZE;i++) {
506     X[i].r += rf[i]*P[i].r;
507     X[i].i += rf[i]*P[i].i;
508   }
509   float newE[NB_BANDS];
510   compute_band_energy(st, newE, X);
511   float norm[NB_BANDS];
512   float normf[FREQ_SIZE]={0};
513   for (i=0;i<NB_BANDS;i++) {
514     norm[i] = sqrt(Ex[i]/(1e-8+newE[i]));
515   }
516   interp_band_gain(st, normf, norm);
517   for (i=0;i<FREQ_SIZE;i++) {
518     X[i].r *= normf[i];
519     X[i].i *= normf[i];
520   }
521 }
522 
523 float rnnoise_process_frame(DenoiseState *st, float *out, const float *in) {
524   int i;
525   kiss_fft_cpx X[FREQ_SIZE];
526   kiss_fft_cpx P[WINDOW_SIZE];
527   float x[FRAME_SIZE];
528   float Ex[NB_BANDS], Ep[NB_BANDS];
529   float Exp[NB_BANDS];
530   float features[NB_FEATURES];
531   float g[NB_BANDS];
532   float gf[FREQ_SIZE]={1};
533   float vad_prob = 0;
534   int silence;
535   static const float a_hp[2] = {-1.99599, 0.99600};
536   static const float b_hp[2] = {-2, 1};
537   biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
538   silence = compute_frame_features(st, X, P, Ex, Ep, Exp, features, x);
539 
540   if (!silence) {
541     compute_rnn(&st->rnn, g, &vad_prob, features);
542     pitch_filter(st, X, P, Ex, Ep, Exp, g);
543     for (i=0;i<NB_BANDS;i++) {
544       float alpha = .6f;
545       g[i] = MAX16(g[i], alpha*st->lastg[i]);
546       st->lastg[i] = g[i];
547     }
548 
549     /* Apply maximum attenuation (minimum value) */
550     if (st->max_attenuation) {
551       float min = 1, mult;
552       for (i=0;i<NB_BANDS;i++) {
553         if (g[i] < min) min = g[i];
554       }
555       if (min < st->max_attenuation) {
556         if (min < MIN_MAX_ATTENUATION)
557           min = MIN_MAX_ATTENUATION;
558         mult = (1.0f-st->max_attenuation) / (1.0f-min);
559         for (i=0;i<NB_BANDS;i++) {
560           if (g[i] < MIN_MAX_ATTENUATION) g[i] = MIN_MAX_ATTENUATION;
561           g[i] = 1.0f-((1.0f-g[i]) * mult);
562           st->lastg[i] = g[i];
563         }
564       }
565     }
566 
567     interp_band_gain(st, gf, g);
568 #if 1
569     for (i=0;i<FREQ_SIZE;i++) {
570       X[i].r *= gf[i];
571       X[i].i *= gf[i];
572     }
573 #endif
574   }
575 
576   frame_synthesis(st, out, X);
577   return vad_prob;
578 }
579 
580 void rnnoise_set_param(DenoiseState *st, int param, float value)
581 {
582   switch (param) {
583     case RNNOISE_PARAM_MAX_ATTENUATION:
584       if ((value > MIN_MAX_ATTENUATION && value <= 1) || value == 0)
585         st->max_attenuation = value;
586       else
587         st->max_attenuation = MIN_MAX_ATTENUATION;
588       break;
589 
590     case RNNOISE_PARAM_SAMPLE_RATE:
591       if (value <= 0)
592         st->sample_rate = 0;
593       else
594         st->sample_rate = value;
595       break;
596   }
597 }
598 
599 #if TRAINING
600 
/* Uniform pseudo-random value in [-0.5, 0.5], drawn from rand(). */
static float uni_rand() {
  const double unit = rand()/(double)RAND_MAX;
  return unit - .5;
}

/* Random second-order response: a[0], a[1], b[0], b[1] (in that draw order)
 * each uniform in [-0.375, 0.375]. */
static void rand_resp(float *a, float *b) {
  float *const coef[4] = {&a[0], &a[1], &b[0], &b[1]};
  int k;
  for (k = 0; k < 4; k++)
    *coef[k] = .75*uni_rand();
}
611 
612 int main(int argc, char **argv) {
613   int i;
614   int count=0;
615   static const float a_hp[2] = {-1.99599, 0.99600};
616   static const float b_hp[2] = {-2, 1};
617   float a_noise[2] = {0};
618   float b_noise[2] = {0};
619   float a_sig[2] = {0};
620   float b_sig[2] = {0};
621   float mem_hp_x[2]={0};
622   float mem_hp_n[2]={0};
623   float mem_resp_x[2]={0};
624   float mem_resp_n[2]={0};
625   float x[FRAME_SIZE];
626   float n[FRAME_SIZE];
627   float xn[FRAME_SIZE];
628   int vad_cnt=0;
629   int gain_change_count=0;
630   float speech_gain = 1, noise_gain = 1;
631   FILE *f1, *f2;
632   int maxCount;
633   DenoiseState *st;
634   DenoiseState *noise_state;
635   DenoiseState *noisy;
636   st = rnnoise_create(NULL);
637   noise_state = rnnoise_create(NULL);
638   noisy = rnnoise_create(NULL);
639   if (argc!=4) {
640     fprintf(stderr, "usage: %s <speech> <noise> <count>\n", argv[0]);
641     return 1;
642   }
643   f1 = fopen(argv[1], "r");
644   f2 = fopen(argv[2], "r");
645   maxCount = atoi(argv[3]);
646   for(i=0;i<150;i++) {
647     short tmp[FRAME_SIZE];
648     fread(tmp, sizeof(short), FRAME_SIZE, f2);
649   }
650   while (1) {
651     kiss_fft_cpx X[FREQ_SIZE], Y[FREQ_SIZE], N[FREQ_SIZE], P[WINDOW_SIZE];
652     float Ex[NB_BANDS], Ey[NB_BANDS], En[NB_BANDS], Ep[NB_BANDS];
653     float Exp[NB_BANDS];
654     float Ln[NB_BANDS];
655     float features[NB_FEATURES];
656     float g[NB_BANDS];
657     short tmp[FRAME_SIZE];
658     float vad=0;
659     float E=0;
660     if (count==maxCount) break;
661     if ((count%1000)==0) fprintf(stderr, "%d\r", count);
662     if (++gain_change_count > 2821) {
663       speech_gain = pow(10., (-40+(rand()%60))/20.);
664       noise_gain = pow(10., (-30+(rand()%50))/20.);
665       if (rand()%10==0) noise_gain = 0;
666       noise_gain *= speech_gain;
667       if (rand()%10==0) speech_gain = 0;
668       gain_change_count = 0;
669       rand_resp(a_noise, b_noise);
670       rand_resp(a_sig, b_sig);
671       lowpass = FREQ_SIZE * 3000./24000. * pow(50., rand()/(double)RAND_MAX);
672       for (i=0;i<NB_BANDS;i++) {
673         if (eband5ms[i]<<FRAME_SIZE_SHIFT > lowpass) {
674           band_lp = i;
675           break;
676         }
677       }
678     }
679     if (speech_gain != 0) {
680       fread(tmp, sizeof(short), FRAME_SIZE, f1);
681       if (feof(f1)) {
682         rewind(f1);
683         fread(tmp, sizeof(short), FRAME_SIZE, f1);
684       }
685       for (i=0;i<FRAME_SIZE;i++) x[i] = speech_gain*tmp[i];
686       for (i=0;i<FRAME_SIZE;i++) E += tmp[i]*(float)tmp[i];
687     } else {
688       for (i=0;i<FRAME_SIZE;i++) x[i] = 0;
689       E = 0;
690     }
691     if (noise_gain!=0) {
692       fread(tmp, sizeof(short), FRAME_SIZE, f2);
693       if (feof(f2)) {
694         rewind(f2);
695         fread(tmp, sizeof(short), FRAME_SIZE, f2);
696       }
697       for (i=0;i<FRAME_SIZE;i++) n[i] = noise_gain*tmp[i];
698     } else {
699       for (i=0;i<FRAME_SIZE;i++) n[i] = 0;
700     }
701     biquad(x, mem_hp_x, x, b_hp, a_hp, FRAME_SIZE);
702     biquad(x, mem_resp_x, x, b_sig, a_sig, FRAME_SIZE);
703     biquad(n, mem_hp_n, n, b_hp, a_hp, FRAME_SIZE);
704     biquad(n, mem_resp_n, n, b_noise, a_noise, FRAME_SIZE);
705     for (i=0;i<FRAME_SIZE;i++) xn[i] = x[i] + n[i];
706     if (E > 1e9f) {
707       vad_cnt=0;
708     } else if (E > 1e8f) {
709       vad_cnt -= 5;
710     } else if (E > 1e7f) {
711       vad_cnt++;
712     } else {
713       vad_cnt+=2;
714     }
715     if (vad_cnt < 0) vad_cnt = 0;
716     if (vad_cnt > 15) vad_cnt = 15;
717 
718     if (vad_cnt >= 10) vad = 0;
719     else if (vad_cnt > 0) vad = 0.5f;
720     else vad = 1.f;
721 
722     frame_analysis(st, Y, Ey, x);
723     frame_analysis(noise_state, N, En, n);
724     for (i=0;i<NB_BANDS;i++) Ln[i] = log10(1e-2+En[i]);
725     int silence = compute_frame_features(noisy, X, P, Ex, Ep, Exp, features, xn);
726     pitch_filter(st, X, P, Ex, Ep, Exp, g);
727     //printf("%f %d\n", noisy->last_gain, noisy->last_period);
728     for (i=0;i<NB_BANDS;i++) {
729       g[i] = sqrt((Ey[i]+1e-3)/(Ex[i]+1e-3));
730       if (g[i] > 1) g[i] = 1;
731       if (silence || i > band_lp) g[i] = -1;
732       if (Ey[i] < 5e-2 && Ex[i] < 5e-2) g[i] = -1;
733       if (vad==0 && noise_gain==0) g[i] = -1;
734     }
735     count++;
736 #if 1
737     fwrite(features, sizeof(float), NB_FEATURES, stdout);
738     fwrite(g, sizeof(float), NB_BANDS, stdout);
739     fwrite(Ln, sizeof(float), NB_BANDS, stdout);
740     fwrite(&vad, sizeof(float), 1, stdout);
741 #endif
742   }
743   fprintf(stderr, "matrix size: %d x %d\n", count, NB_FEATURES + 2*NB_BANDS + 1);
744   fclose(f1);
745   fclose(f2);
746   return 0;
747 }
748 
749 #endif
750