1 use std::fs::File;
2 use std::io::{BufReader, BufWriter, Read, Seek, Write};
3 use std::path::Path;
4 use std::str::FromStr;
5 
6 use anyhow::{anyhow, Context, Error};
7 use clap::{crate_version, App, Arg};
8 use dasp_interpolate::{sinc::Sinc, Interpolator};
9 use dasp_ring_buffer::Fixed;
10 use hound::{SampleFormat, WavReader, WavSpec, WavWriter};
11 
12 use nnnoiseless::{DenoiseState, RnnModel};
13 
14 const FRAME_SIZE: usize = DenoiseState::FRAME_SIZE;
15 
16 trait ReadSample {
next_sample(&mut self) -> Result<Option<&[f32]>, Error>17     fn next_sample(&mut self) -> Result<Option<&[f32]>, Error>;
channels(&self) -> usize18     fn channels(&self) -> usize;
19 
resampled(self, ratio: f64) -> Resample<Self> where Self: Sized,20     fn resampled(self, ratio: f64) -> Resample<Self>
21     where
22         Self: Sized,
23     {
24         Resample {
25             sinc: (0..self.channels())
26                 .map(|_| Sinc::new(Fixed::from([0.0; 16])))
27                 .collect(),
28             buf: vec![0.0; self.channels()],
29             ratio,
30             pos: 0.0,
31             read: self,
32         }
33     }
34 }
35 
/// Iterator over raw 16-bit samples, decoded two bytes at a time from a
/// byte stream.
// TODO: support either endianness (only little-endian is decoded today).
struct RawSampleIter<R>
where
    R: Read,
{
    /// Byte-by-byte view of the underlying reader.
    bytes: std::io::Bytes<R>,
}
40 
41 struct Resample<RS: ReadSample> {
42     sinc: Vec<Sinc<[f32; 16]>>,
43     buf: Vec<f32>,
44     ratio: f64,
45     pos: f64,
46     read: RS,
47 }
48 
/// Groups a flat stream of samples into per-channel sets.
///
/// The length of `buf` is the channel count.
struct IterReadSample<I> {
    /// The flat (interleaved) stream of samples.
    samples: I,
    /// Holds the values of the set currently being assembled, one per channel.
    buf: Vec<f32>,
}
53 
54 impl<I: Iterator<Item = Result<f32, Error>>> IterReadSample<I> {
new(iter: I, channels: usize) -> IterReadSample<I>55     fn new(iter: I, channels: usize) -> IterReadSample<I> {
56         IterReadSample {
57             samples: iter,
58             buf: vec![0.0; channels],
59         }
60     }
61 }
62 
63 impl<R: Read> Iterator for RawSampleIter<R> {
64     type Item = Result<f32, Error>;
65 
next(&mut self) -> Option<Result<f32, Error>>66     fn next(&mut self) -> Option<Result<f32, Error>> {
67         match self.bytes.next() {
68             None => None,
69             Some(Err(e)) => Some(Err(e.into())),
70             Some(Ok(a)) => match self.bytes.next() {
71                 None => Some(Err(anyhow!(
72                     "Unexpected end of input (expected an even number of bytes)"
73                 ))),
74                 Some(Err(e)) => Some(Err(e.into())),
75                 Some(Ok(b)) => Some(Ok(i16::from_le_bytes([a, b]) as f32)),
76             },
77         }
78     }
79 }
80 
81 impl<I: Iterator<Item = Result<f32, Error>>> ReadSample for IterReadSample<I> {
next_sample(&mut self) -> Result<Option<&[f32]>, Error>82     fn next_sample(&mut self) -> Result<Option<&[f32]>, Error> {
83         for (i, sample) in self.buf.iter_mut().enumerate() {
84             match self.samples.next() {
85                 None => {
86                     if i == 0 {
87                         return Ok(None);
88                     } else {
89                         return Err(anyhow!(
90                             "Unexpected end of input (expected a multiple of {} samples)",
91                             self.buf.len()
92                         ));
93                     }
94                 }
95                 Some(Err(e)) => return Err(e),
96                 Some(Ok(x)) => *sample = x,
97             }
98         }
99         Ok(Some(&self.buf[..]))
100     }
101 
channels(&self) -> usize102     fn channels(&self) -> usize {
103         self.buf.len()
104     }
105 }
106 
107 impl<RS: ReadSample> ReadSample for Resample<RS> {
next_sample(&mut self) -> Result<Option<&[f32]>, Error>108     fn next_sample(&mut self) -> Result<Option<&[f32]>, Error> {
109         self.pos += self.ratio;
110         while self.pos >= 1.0 {
111             self.pos -= 1.0;
112 
113             if let Some(buf) = self.read.next_sample()? {
114                 for (s, &x) in self.sinc.iter_mut().zip(buf) {
115                     s.next_source_frame(x);
116                 }
117             } else {
118                 return Ok(None);
119             }
120         }
121 
122         for (s, x) in self.sinc.iter().zip(&mut self.buf) {
123             *x = s.interpolate(self.pos);
124         }
125 
126         Ok(Some(&self.buf[..]))
127     }
128 
channels(&self) -> usize129     fn channels(&self) -> usize {
130         self.read.channels()
131     }
132 }
133 
134 trait FrameWriter {
write_frame(&mut self, buf: &[f32]) -> Result<(), Error>135     fn write_frame(&mut self, buf: &[f32]) -> Result<(), Error>;
finalize(&mut self) -> Result<(), Error>136     fn finalize(&mut self) -> Result<(), Error>;
137 }
138 
/// Writes frames as headerless 16-bit samples.
struct RawFrameWriter<W>
where
    W: Write,
{
    /// Destination for the encoded bytes.
    writer: W,
    /// Scratch buffer holding the encoded bytes of one frame (two bytes per sample).
    buf: Vec<u8>,
}
143 
144 struct WavFrameWriter<W: Write + Seek> {
145     writer: WavWriter<W>,
146 }
147 
148 impl<W: Write> FrameWriter for RawFrameWriter<W> {
write_frame(&mut self, buf: &[f32]) -> Result<(), Error>149     fn write_frame(&mut self, buf: &[f32]) -> Result<(), Error> {
150         assert_eq!(buf.len() * 2, self.buf.len());
151         for (dst, src) in self.buf.chunks_mut(2).zip(buf) {
152             let bytes =
153                 (src.max(i16::MIN as f32).min(i16::MAX as f32).round() as i16).to_le_bytes();
154             dst[0] = bytes[0];
155             dst[1] = bytes[1];
156         }
157         self.writer.write_all(&self.buf[..]).map_err(|e| e.into())
158     }
159 
finalize(&mut self) -> Result<(), Error>160     fn finalize(&mut self) -> Result<(), Error> {
161         self.writer.flush()?;
162         Ok(())
163     }
164 }
165 
166 impl<W: Write + Seek> FrameWriter for WavFrameWriter<W> {
write_frame(&mut self, buf: &[f32]) -> Result<(), Error>167     fn write_frame(&mut self, buf: &[f32]) -> Result<(), Error> {
168         let mut w = self.writer.get_i16_writer(buf.len() as u32);
169         for &x in buf {
170             w.write_sample(x.max(i16::MIN as f32).min(i16::MAX as f32).round() as i16);
171         }
172         w.flush().map_err(|e| e.into())
173     }
174 
finalize(&mut self) -> Result<(), Error>175     fn finalize(&mut self) -> Result<(), Error> {
176         self.writer.flush().map_err(|e| e.into())
177     }
178 }
179 
raw_samples<R: Read + 'static>(r: R, channels: usize, sample_rate: f64) -> Box<dyn ReadSample>180 fn raw_samples<R: Read + 'static>(r: R, channels: usize, sample_rate: f64) -> Box<dyn ReadSample> {
181     let raw = IterReadSample::new(RawSampleIter { bytes: r.bytes() }, channels);
182 
183     if sample_rate != 48_000.0 {
184         Box::new(raw.resampled(sample_rate / 48_000.0))
185     } else {
186         Box::new(raw)
187     }
188 }
189 
wav_samples<R: Read + 'static>(wav: WavReader<R>) -> Box<dyn ReadSample>190 fn wav_samples<R: Read + 'static>(wav: WavReader<R>) -> Box<dyn ReadSample> {
191     let sample_rate = wav.spec().sample_rate as f64;
192     let channels = wav.spec().channels as usize;
193     match wav.spec().sample_format {
194         SampleFormat::Int => {
195             let bits_per_sample = wav.spec().bits_per_sample;
196             assert!(bits_per_sample <= 32);
197 
198             let iter = wav.into_samples::<i32>().map(move |s| {
199                 s.map(|s| {
200                     if bits_per_sample < 16 {
201                         (s << (16 - bits_per_sample)) as f32
202                     } else {
203                         (s >> (bits_per_sample - 16)) as f32
204                     }
205                 })
206                 .map_err(|e| e.into())
207             });
208 
209             let read_sample = IterReadSample::new(iter, channels);
210             if sample_rate != 48_000.0 {
211                 Box::new(read_sample.resampled(sample_rate / 48_000.0))
212             } else {
213                 Box::new(read_sample)
214             }
215         }
216         SampleFormat::Float => {
217             let iter = wav
218                 .into_samples::<f32>()
219                 .map(|s| s.map(|s| s * 32767.0).map_err(|e| e.into()));
220 
221             let read_sample = IterReadSample::new(iter, channels);
222             if sample_rate != 48_000.0 {
223                 Box::new(read_sample.resampled(sample_rate / 48_000.0))
224             } else {
225                 Box::new(read_sample)
226             }
227         }
228     }
229 }
230 
main() -> Result<(), Box<dyn std::error::Error>>231 fn main() -> Result<(), Box<dyn std::error::Error>> {
232     let matches = App::new("nnnoiseless")
233         .version(crate_version!())
234         .about("Remove noise from audio files")
235         .arg(
236             Arg::with_name("INPUT")
237                 .help("input audio file")
238                 .required(true),
239         )
240         .arg(
241             Arg::with_name("OUTPUT")
242                 .help("output audio file")
243                 .required(true),
244         )
245         .arg(Arg::with_name("wav-in").long("wav-in").help(
246             "if set, the input is a wav file (default is to detect wav files by their filename)",
247         ))
248         .arg(Arg::with_name("wav-out").long("wav-out").help(
249             "if set, the output is a wav file (default is to detect wav files by their filename)",
250         ))
251         .arg(
252             Arg::with_name("sample-rate")
253                 .long("sample-rate")
254                 .help("for raw input, the sample rate of the input (defaults to 48kHz)")
255                 .takes_value(true),
256         )
257         .arg(
258             Arg::with_name("channels")
259                 .long("channels")
260                 .help("for raw input, the number of channels (defaults to 1)")
261                 .takes_value(true),
262         )
263         .arg(
264             Arg::with_name("model")
265                 .long("model")
266                 .help("path to a custom model file")
267                 .takes_value(true),
268         )
269         .get_matches();
270 
271     let in_name = matches.value_of("INPUT").unwrap();
272     let out_name = matches.value_of("OUTPUT").unwrap();
273     let in_file = BufReader::new(
274         File::open(in_name)
275             .with_context(|| format!("Failed to open input file \"{}\"", in_name))?,
276     );
277     let out_file = BufWriter::new(
278         File::create(out_name)
279             .with_context(|| format!("Failed to open output file \"{}\"", out_name))?,
280     );
281     let in_wav =
282         matches.is_present("wav-in") || Path::new(in_name).extension() == Some("wav".as_ref());
283     let out_wav =
284         matches.is_present("wav-out") || Path::new(out_name).extension() == Some("wav".as_ref());
285 
286     let (mut samples, channels) = if in_wav {
287         let wav_reader = WavReader::new(in_file)?;
288         let channels = wav_reader.spec().channels;
289         (wav_samples(wav_reader), channels)
290     } else {
291         // TODO: report parse errors
292         let sample_rate = matches
293             .value_of("sample-rate")
294             .and_then(|s| f64::from_str(s).ok())
295             .unwrap_or(48_000.0);
296         let channels = matches
297             .value_of("channels")
298             .and_then(|s| u16::from_str(s).ok())
299             .unwrap_or(1);
300         (
301             raw_samples(in_file, channels as usize, sample_rate),
302             channels,
303         )
304     };
305 
306     let mut frame_writer: Box<dyn FrameWriter> = if out_wav {
307         let spec = WavSpec {
308             channels,
309             sample_rate: 48_000,
310             bits_per_sample: 16,
311             sample_format: SampleFormat::Int,
312         };
313         let writer = WavWriter::new(out_file, spec)?;
314         Box::new(WavFrameWriter { writer })
315     } else {
316         Box::new(RawFrameWriter {
317             writer: out_file,
318             buf: vec![0; FRAME_SIZE * 2],
319         })
320     };
321 
322     let model = if let Some(model_path) = matches.value_of("model") {
323         RnnModel::from_read(BufReader::new(
324             File::open(model_path).context("Failed to open model file")?,
325         ))
326         .context("Failed to read model file")?
327     } else {
328         RnnModel::default()
329     };
330     let channels = channels as usize;
331     let mut in_bufs = vec![vec![0.0; FRAME_SIZE]; channels];
332     let mut out_bufs = vec![vec![0.0; FRAME_SIZE]; channels];
333     let mut out_buf = vec![0.0; FRAME_SIZE * channels];
334     let mut states = vec![DenoiseState::with_model(&model); channels];
335     let mut first = true;
336     'outer: loop {
337         for i in 0..FRAME_SIZE {
338             if let Some(buf) = samples.next_sample()? {
339                 for j in 0..channels {
340                     in_bufs[j][i] = buf[j];
341                 }
342             } else {
343                 break 'outer;
344             }
345         }
346 
347         for j in 0..channels {
348             states[j].process_frame(&mut out_bufs[j], &in_bufs[j]);
349         }
350         if !first {
351             for i in 0..FRAME_SIZE {
352                 for j in 0..channels {
353                     out_buf[i * channels + j] = out_bufs[j][i];
354                 }
355             }
356             frame_writer.write_frame(&out_buf[..])?;
357         }
358         first = false;
359     }
360     frame_writer.finalize()?;
361 
362     Ok(())
363 }
364