1 use std::fs::File;
2 use std::io::{BufReader, BufWriter, Read, Seek, Write};
3 use std::path::Path;
4 use std::str::FromStr;
5
6 use anyhow::{anyhow, Context, Error};
7 use clap::{crate_version, App, Arg};
8 use dasp_interpolate::{sinc::Sinc, Interpolator};
9 use dasp_ring_buffer::Fixed;
10 use hound::{SampleFormat, WavReader, WavSpec, WavWriter};
11
12 use nnnoiseless::{DenoiseState, RnnModel};
13
14 const FRAME_SIZE: usize = DenoiseState::FRAME_SIZE;
15
16 trait ReadSample {
next_sample(&mut self) -> Result<Option<&[f32]>, Error>17 fn next_sample(&mut self) -> Result<Option<&[f32]>, Error>;
channels(&self) -> usize18 fn channels(&self) -> usize;
19
resampled(self, ratio: f64) -> Resample<Self> where Self: Sized,20 fn resampled(self, ratio: f64) -> Resample<Self>
21 where
22 Self: Sized,
23 {
24 Resample {
25 sinc: (0..self.channels())
26 .map(|_| Sinc::new(Fixed::from([0.0; 16])))
27 .collect(),
28 buf: vec![0.0; self.channels()],
29 ratio,
30 pos: 0.0,
31 read: self,
32 }
33 }
34 }
35
// TODO: support either endianness
/// Byte-level view of a reader holding raw audio.
///
/// The `Iterator` implementation in this file consumes the bytes two at a time and decodes
/// each pair as a little-endian `i16`.
struct RawSampleIter<R>
where
    R: Read,
{
    /// Remaining bytes of the underlying reader.
    bytes: std::io::Bytes<R>,
}
40
41 struct Resample<RS: ReadSample> {
42 sinc: Vec<Sinc<[f32; 16]>>,
43 buf: Vec<f32>,
44 ratio: f64,
45 pos: f64,
46 read: RS,
47 }
48
49 struct IterReadSample<I> {
50 samples: I,
51 buf: Vec<f32>,
52 }
53
54 impl<I: Iterator<Item = Result<f32, Error>>> IterReadSample<I> {
new(iter: I, channels: usize) -> IterReadSample<I>55 fn new(iter: I, channels: usize) -> IterReadSample<I> {
56 IterReadSample {
57 samples: iter,
58 buf: vec![0.0; channels],
59 }
60 }
61 }
62
63 impl<R: Read> Iterator for RawSampleIter<R> {
64 type Item = Result<f32, Error>;
65
next(&mut self) -> Option<Result<f32, Error>>66 fn next(&mut self) -> Option<Result<f32, Error>> {
67 match self.bytes.next() {
68 None => None,
69 Some(Err(e)) => Some(Err(e.into())),
70 Some(Ok(a)) => match self.bytes.next() {
71 None => Some(Err(anyhow!(
72 "Unexpected end of input (expected an even number of bytes)"
73 ))),
74 Some(Err(e)) => Some(Err(e.into())),
75 Some(Ok(b)) => Some(Ok(i16::from_le_bytes([a, b]) as f32)),
76 },
77 }
78 }
79 }
80
81 impl<I: Iterator<Item = Result<f32, Error>>> ReadSample for IterReadSample<I> {
next_sample(&mut self) -> Result<Option<&[f32]>, Error>82 fn next_sample(&mut self) -> Result<Option<&[f32]>, Error> {
83 for (i, sample) in self.buf.iter_mut().enumerate() {
84 match self.samples.next() {
85 None => {
86 if i == 0 {
87 return Ok(None);
88 } else {
89 return Err(anyhow!(
90 "Unexpected end of input (expected a multiple of {} samples)",
91 self.buf.len()
92 ));
93 }
94 }
95 Some(Err(e)) => return Err(e),
96 Some(Ok(x)) => *sample = x,
97 }
98 }
99 Ok(Some(&self.buf[..]))
100 }
101
channels(&self) -> usize102 fn channels(&self) -> usize {
103 self.buf.len()
104 }
105 }
106
107 impl<RS: ReadSample> ReadSample for Resample<RS> {
next_sample(&mut self) -> Result<Option<&[f32]>, Error>108 fn next_sample(&mut self) -> Result<Option<&[f32]>, Error> {
109 self.pos += self.ratio;
110 while self.pos >= 1.0 {
111 self.pos -= 1.0;
112
113 if let Some(buf) = self.read.next_sample()? {
114 for (s, &x) in self.sinc.iter_mut().zip(buf) {
115 s.next_source_frame(x);
116 }
117 } else {
118 return Ok(None);
119 }
120 }
121
122 for (s, x) in self.sinc.iter().zip(&mut self.buf) {
123 *x = s.interpolate(self.pos);
124 }
125
126 Ok(Some(&self.buf[..]))
127 }
128
channels(&self) -> usize129 fn channels(&self) -> usize {
130 self.read.channels()
131 }
132 }
133
134 trait FrameWriter {
write_frame(&mut self, buf: &[f32]) -> Result<(), Error>135 fn write_frame(&mut self, buf: &[f32]) -> Result<(), Error>;
finalize(&mut self) -> Result<(), Error>136 fn finalize(&mut self) -> Result<(), Error>;
137 }
138
/// Writes frames as headerless 16-bit samples.
struct RawFrameWriter<W>
where
    W: Write,
{
    /// Destination for the encoded bytes.
    writer: W,
    /// Scratch space for one encoded frame, two bytes per sample.
    buf: Vec<u8>,
}
143
144 struct WavFrameWriter<W: Write + Seek> {
145 writer: WavWriter<W>,
146 }
147
148 impl<W: Write> FrameWriter for RawFrameWriter<W> {
write_frame(&mut self, buf: &[f32]) -> Result<(), Error>149 fn write_frame(&mut self, buf: &[f32]) -> Result<(), Error> {
150 assert_eq!(buf.len() * 2, self.buf.len());
151 for (dst, src) in self.buf.chunks_mut(2).zip(buf) {
152 let bytes =
153 (src.max(i16::MIN as f32).min(i16::MAX as f32).round() as i16).to_le_bytes();
154 dst[0] = bytes[0];
155 dst[1] = bytes[1];
156 }
157 self.writer.write_all(&self.buf[..]).map_err(|e| e.into())
158 }
159
finalize(&mut self) -> Result<(), Error>160 fn finalize(&mut self) -> Result<(), Error> {
161 self.writer.flush()?;
162 Ok(())
163 }
164 }
165
166 impl<W: Write + Seek> FrameWriter for WavFrameWriter<W> {
write_frame(&mut self, buf: &[f32]) -> Result<(), Error>167 fn write_frame(&mut self, buf: &[f32]) -> Result<(), Error> {
168 let mut w = self.writer.get_i16_writer(buf.len() as u32);
169 for &x in buf {
170 w.write_sample(x.max(i16::MIN as f32).min(i16::MAX as f32).round() as i16);
171 }
172 w.flush().map_err(|e| e.into())
173 }
174
finalize(&mut self) -> Result<(), Error>175 fn finalize(&mut self) -> Result<(), Error> {
176 self.writer.flush().map_err(|e| e.into())
177 }
178 }
179
raw_samples<R: Read + 'static>(r: R, channels: usize, sample_rate: f64) -> Box<dyn ReadSample>180 fn raw_samples<R: Read + 'static>(r: R, channels: usize, sample_rate: f64) -> Box<dyn ReadSample> {
181 let raw = IterReadSample::new(RawSampleIter { bytes: r.bytes() }, channels);
182
183 if sample_rate != 48_000.0 {
184 Box::new(raw.resampled(sample_rate / 48_000.0))
185 } else {
186 Box::new(raw)
187 }
188 }
189
wav_samples<R: Read + 'static>(wav: WavReader<R>) -> Box<dyn ReadSample>190 fn wav_samples<R: Read + 'static>(wav: WavReader<R>) -> Box<dyn ReadSample> {
191 let sample_rate = wav.spec().sample_rate as f64;
192 let channels = wav.spec().channels as usize;
193 match wav.spec().sample_format {
194 SampleFormat::Int => {
195 let bits_per_sample = wav.spec().bits_per_sample;
196 assert!(bits_per_sample <= 32);
197
198 let iter = wav.into_samples::<i32>().map(move |s| {
199 s.map(|s| {
200 if bits_per_sample < 16 {
201 (s << (16 - bits_per_sample)) as f32
202 } else {
203 (s >> (bits_per_sample - 16)) as f32
204 }
205 })
206 .map_err(|e| e.into())
207 });
208
209 let read_sample = IterReadSample::new(iter, channels);
210 if sample_rate != 48_000.0 {
211 Box::new(read_sample.resampled(sample_rate / 48_000.0))
212 } else {
213 Box::new(read_sample)
214 }
215 }
216 SampleFormat::Float => {
217 let iter = wav
218 .into_samples::<f32>()
219 .map(|s| s.map(|s| s * 32767.0).map_err(|e| e.into()));
220
221 let read_sample = IterReadSample::new(iter, channels);
222 if sample_rate != 48_000.0 {
223 Box::new(read_sample.resampled(sample_rate / 48_000.0))
224 } else {
225 Box::new(read_sample)
226 }
227 }
228 }
229 }
230
main() -> Result<(), Box<dyn std::error::Error>>231 fn main() -> Result<(), Box<dyn std::error::Error>> {
232 let matches = App::new("nnnoiseless")
233 .version(crate_version!())
234 .about("Remove noise from audio files")
235 .arg(
236 Arg::with_name("INPUT")
237 .help("input audio file")
238 .required(true),
239 )
240 .arg(
241 Arg::with_name("OUTPUT")
242 .help("output audio file")
243 .required(true),
244 )
245 .arg(Arg::with_name("wav-in").long("wav-in").help(
246 "if set, the input is a wav file (default is to detect wav files by their filename)",
247 ))
248 .arg(Arg::with_name("wav-out").long("wav-out").help(
249 "if set, the output is a wav file (default is to detect wav files by their filename)",
250 ))
251 .arg(
252 Arg::with_name("sample-rate")
253 .long("sample-rate")
254 .help("for raw input, the sample rate of the input (defaults to 48kHz)")
255 .takes_value(true),
256 )
257 .arg(
258 Arg::with_name("channels")
259 .long("channels")
260 .help("for raw input, the number of channels (defaults to 1)")
261 .takes_value(true),
262 )
263 .arg(
264 Arg::with_name("model")
265 .long("model")
266 .help("path to a custom model file")
267 .takes_value(true),
268 )
269 .get_matches();
270
271 let in_name = matches.value_of("INPUT").unwrap();
272 let out_name = matches.value_of("OUTPUT").unwrap();
273 let in_file = BufReader::new(
274 File::open(in_name)
275 .with_context(|| format!("Failed to open input file \"{}\"", in_name))?,
276 );
277 let out_file = BufWriter::new(
278 File::create(out_name)
279 .with_context(|| format!("Failed to open output file \"{}\"", out_name))?,
280 );
281 let in_wav =
282 matches.is_present("wav-in") || Path::new(in_name).extension() == Some("wav".as_ref());
283 let out_wav =
284 matches.is_present("wav-out") || Path::new(out_name).extension() == Some("wav".as_ref());
285
286 let (mut samples, channels) = if in_wav {
287 let wav_reader = WavReader::new(in_file)?;
288 let channels = wav_reader.spec().channels;
289 (wav_samples(wav_reader), channels)
290 } else {
291 // TODO: report parse errors
292 let sample_rate = matches
293 .value_of("sample-rate")
294 .and_then(|s| f64::from_str(s).ok())
295 .unwrap_or(48_000.0);
296 let channels = matches
297 .value_of("channels")
298 .and_then(|s| u16::from_str(s).ok())
299 .unwrap_or(1);
300 (
301 raw_samples(in_file, channels as usize, sample_rate),
302 channels,
303 )
304 };
305
306 let mut frame_writer: Box<dyn FrameWriter> = if out_wav {
307 let spec = WavSpec {
308 channels,
309 sample_rate: 48_000,
310 bits_per_sample: 16,
311 sample_format: SampleFormat::Int,
312 };
313 let writer = WavWriter::new(out_file, spec)?;
314 Box::new(WavFrameWriter { writer })
315 } else {
316 Box::new(RawFrameWriter {
317 writer: out_file,
318 buf: vec![0; FRAME_SIZE * 2],
319 })
320 };
321
322 let model = if let Some(model_path) = matches.value_of("model") {
323 RnnModel::from_read(BufReader::new(
324 File::open(model_path).context("Failed to open model file")?,
325 ))
326 .context("Failed to read model file")?
327 } else {
328 RnnModel::default()
329 };
330 let channels = channels as usize;
331 let mut in_bufs = vec![vec![0.0; FRAME_SIZE]; channels];
332 let mut out_bufs = vec![vec![0.0; FRAME_SIZE]; channels];
333 let mut out_buf = vec![0.0; FRAME_SIZE * channels];
334 let mut states = vec![DenoiseState::with_model(&model); channels];
335 let mut first = true;
336 'outer: loop {
337 for i in 0..FRAME_SIZE {
338 if let Some(buf) = samples.next_sample()? {
339 for j in 0..channels {
340 in_bufs[j][i] = buf[j];
341 }
342 } else {
343 break 'outer;
344 }
345 }
346
347 for j in 0..channels {
348 states[j].process_frame(&mut out_bufs[j], &in_bufs[j]);
349 }
350 if !first {
351 for i in 0..FRAME_SIZE {
352 for j in 0..channels {
353 out_buf[i * channels + j] = out_bufs[j][i];
354 }
355 }
356 frame_writer.write_frame(&out_buf[..])?;
357 }
358 first = false;
359 }
360 frame_writer.finalize()?;
361
362 Ok(())
363 }
364