1 #include "STFT.h"
2 
3 using namespace openshot;
4 
setup(const int num_input_channels)5 void STFT::setup(const int num_input_channels)
6 {
7     num_channels = (num_input_channels > 0) ? num_input_channels : 1;
8 }
9 
updateParameters(const int new_fft_size,const int new_overlap,const int new_window_type)10 void STFT::updateParameters(const int new_fft_size, const int new_overlap, const int new_window_type)
11 {
12     updateFftSize(new_fft_size);
13     updateHopSize(new_overlap);
14     updateWindow(new_window_type);
15 }
16 
process(juce::AudioSampleBuffer & block)17 void STFT::process(juce::AudioSampleBuffer &block)
18 {
19     num_samples = block.getNumSamples();
20 
21     for (int channel = 0; channel < num_channels; ++channel) {
22         float *channel_data = block.getWritePointer(channel);
23 
24         current_input_buffer_write_position = input_buffer_write_position;
25         current_output_buffer_write_position = output_buffer_write_position;
26         current_output_buffer_read_position = output_buffer_read_position;
27         current_samples_since_last_FFT = samples_since_last_FFT;
28 
29         for (int sample = 0; sample < num_samples; ++sample) {
30             const float input_sample = channel_data[sample];
31 
32             input_buffer.setSample(channel, current_input_buffer_write_position, input_sample);
33             if (++current_input_buffer_write_position >= input_buffer_length)
34                 current_input_buffer_write_position = 0;
35             // diff
36             channel_data[sample] = output_buffer.getSample(channel, current_output_buffer_read_position);
37 
38             output_buffer.setSample(channel, current_output_buffer_read_position, 0.0f);
39             if (++current_output_buffer_read_position >= output_buffer_length)
40                 current_output_buffer_read_position = 0;
41 
42             if (++current_samples_since_last_FFT >= hop_size) {
43                 current_samples_since_last_FFT = 0;
44                 analysis(channel);
45                 modification(channel);
46                 synthesis(channel);
47             }
48         }
49     }
50 
51     input_buffer_write_position = current_input_buffer_write_position;
52     output_buffer_write_position = current_output_buffer_write_position;
53     output_buffer_read_position = current_output_buffer_read_position;
54     samples_since_last_FFT = current_samples_since_last_FFT;
55 }
56 
57 
updateFftSize(const int new_fft_size)58 void STFT::updateFftSize(const int new_fft_size)
59 {
60     if (new_fft_size != fft_size)
61     {
62         fft_size = new_fft_size;
63         fft = std::make_unique<juce::dsp::FFT>(log2(fft_size));
64 
65         input_buffer_length = fft_size;
66         input_buffer.clear();
67         input_buffer.setSize(num_channels, input_buffer_length);
68 
69         output_buffer_length = fft_size;
70         output_buffer.clear();
71         output_buffer.setSize(num_channels, output_buffer_length);
72 
73         fft_window.realloc(fft_size);
74         fft_window.clear(fft_size);
75 
76         time_domain_buffer.realloc(fft_size);
77         time_domain_buffer.clear(fft_size);
78 
79         frequency_domain_buffer.realloc(fft_size);
80         frequency_domain_buffer.clear(fft_size);
81 
82         input_buffer_write_position = 0;
83         output_buffer_write_position = 0;
84         output_buffer_read_position = 0;
85         samples_since_last_FFT = 0;
86     }
87 }
88 
updateHopSize(const int new_overlap)89 void STFT::updateHopSize(const int new_overlap)
90 {
91     if (new_overlap != overlap)
92     {
93         overlap = new_overlap;
94 
95         if (overlap != 0) {
96             hop_size = fft_size / overlap;
97             output_buffer_write_position = hop_size % output_buffer_length;
98         }
99     }
100 }
101 
102 
updateWindow(const int new_window_type)103 void STFT::updateWindow(const int new_window_type)
104 {
105     window_type = new_window_type;
106 
107     switch (window_type) {
108         case RECTANGULAR: {
109             for (int sample = 0; sample < fft_size; ++sample)
110                 fft_window[sample] = 1.0f;
111             break;
112         }
113         case BART_LETT: {
114             for (int sample = 0; sample < fft_size; ++sample)
115                 fft_window[sample] = 1.0f - fabs (2.0f * (float)sample / (float)(fft_size - 1) - 1.0f);
116             break;
117         }
118         case HANN: {
119             for (int sample = 0; sample < fft_size; ++sample)
120                 fft_window[sample] = 0.5f - 0.5f * cosf (2.0f * M_PI * (float)sample / (float)(fft_size - 1));
121             break;
122         }
123         case HAMMING: {
124             for (int sample = 0; sample < fft_size; ++sample)
125                 fft_window[sample] = 0.54f - 0.46f * cosf (2.0f * M_PI * (float)sample / (float)(fft_size - 1));
126             break;
127         }
128     }
129 
130     float window_sum = 0.0f;
131     for (int sample = 0; sample < fft_size; ++sample)
132         window_sum += fft_window[sample];
133 
134     window_scale_factor = 0.0f;
135     if (overlap != 0 && window_sum != 0.0f)
136         window_scale_factor = 1.0f / (float)overlap / window_sum * (float)fft_size;
137 }
138 
139 
140 
analysis(const int channel)141 void STFT::analysis(const int channel)
142 {
143     int input_buffer_index = current_input_buffer_write_position;
144     for (int index = 0; index < fft_size; ++index) {
145         time_domain_buffer[index].real(fft_window[index] * input_buffer.getSample(channel, input_buffer_index));
146         time_domain_buffer[index].imag(0.0f);
147 
148         if (++input_buffer_index >= input_buffer_length)
149             input_buffer_index = 0;
150     }
151 }
152 
modification(const int channel)153 void STFT::modification(const int channel)
154 {
155     fft->perform(time_domain_buffer, frequency_domain_buffer, false);
156 
157     for (int index = 0; index < fft_size / 2 + 1; ++index) {
158         float magnitude = abs(frequency_domain_buffer[index]);
159         float phase = arg(frequency_domain_buffer[index]);
160 
161         frequency_domain_buffer[index].real(magnitude * cosf (phase));
162         frequency_domain_buffer[index].imag(magnitude * sinf (phase));
163 
164         if (index > 0 && index < fft_size / 2) {
165             frequency_domain_buffer[fft_size - index].real(magnitude * cosf (phase));
166             frequency_domain_buffer[fft_size - index].imag(magnitude * sinf (-phase));
167         }
168     }
169 
170     fft->perform(frequency_domain_buffer, time_domain_buffer, true);
171 }
172 
synthesis(const int channel)173 void STFT::synthesis(const int channel)
174 {
175     int output_buffer_index = current_output_buffer_write_position;
176     for (int index = 0; index < fft_size; ++index) {
177         float output_sample = output_buffer.getSample(channel, output_buffer_index);
178         output_sample += time_domain_buffer[index].real() * window_scale_factor;
179         output_buffer.setSample(channel, output_buffer_index, output_sample);
180 
181         if (++output_buffer_index >= output_buffer_length)
182             output_buffer_index = 0;
183     }
184 
185     current_output_buffer_write_position += hop_size;
186     if (current_output_buffer_write_position >= output_buffer_length)
187         current_output_buffer_write_position = 0;
188 }