1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "modules/audio_device/include/test_audio_device.h"
11 
12 #include <algorithm>
13 #include <cstdint>
14 #include <cstdlib>
15 #include <memory>
16 #include <string>
17 #include <type_traits>
18 #include <utility>
19 #include <vector>
20 
21 #include "api/array_view.h"
22 #include "common_audio/wav_file.h"
23 #include "modules/audio_device/include/audio_device_default.h"
24 #include "rtc_base/buffer.h"
25 #include "rtc_base/checks.h"
26 #include "rtc_base/event.h"
27 #include "rtc_base/logging.h"
28 #include "rtc_base/numerics/safe_conversions.h"
29 #include "rtc_base/platform_thread.h"
30 #include "rtc_base/random.h"
31 #include "rtc_base/ref_counted_object.h"
32 #include "rtc_base/synchronization/mutex.h"
33 #include "rtc_base/task_queue.h"
34 #include "rtc_base/task_utils/repeating_task.h"
35 #include "rtc_base/thread_annotations.h"
36 #include "rtc_base/time_utils.h"
37 
38 namespace webrtc {
39 
40 namespace {
41 
42 constexpr int kFrameLengthUs = 10000;
43 constexpr int kFramesPerSecond = rtc::kNumMicrosecsPerSec / kFrameLengthUs;
44 
45 // TestAudioDeviceModule implements an AudioDevice module that can act both as a
46 // capturer and a renderer. It will use 10ms audio frames.
47 class TestAudioDeviceModuleImpl
48     : public webrtc_impl::AudioDeviceModuleDefault<TestAudioDeviceModule> {
49  public:
50   // Creates a new TestAudioDeviceModule. When capturing or playing, 10 ms audio
51   // frames will be processed every 10ms / |speed|.
52   // |capturer| is an object that produces audio data. Can be nullptr if this
53   // device is never used for recording.
54   // |renderer| is an object that receives audio data that would have been
55   // played out. Can be nullptr if this device is never used for playing.
56   // Use one of the Create... functions to get these instances.
TestAudioDeviceModuleImpl(TaskQueueFactory * task_queue_factory,std::unique_ptr<Capturer> capturer,std::unique_ptr<Renderer> renderer,float speed=1)57   TestAudioDeviceModuleImpl(TaskQueueFactory* task_queue_factory,
58                             std::unique_ptr<Capturer> capturer,
59                             std::unique_ptr<Renderer> renderer,
60                             float speed = 1)
61       : task_queue_factory_(task_queue_factory),
62         capturer_(std::move(capturer)),
63         renderer_(std::move(renderer)),
64         process_interval_us_(kFrameLengthUs / speed),
65         audio_callback_(nullptr),
66         rendering_(false),
67         capturing_(false) {
68     auto good_sample_rate = [](int sr) {
69       return sr == 8000 || sr == 16000 || sr == 32000 || sr == 44100 ||
70              sr == 48000;
71     };
72 
73     if (renderer_) {
74       const int sample_rate = renderer_->SamplingFrequency();
75       playout_buffer_.resize(
76           SamplesPerFrame(sample_rate) * renderer_->NumChannels(), 0);
77       RTC_CHECK(good_sample_rate(sample_rate));
78     }
79     if (capturer_) {
80       RTC_CHECK(good_sample_rate(capturer_->SamplingFrequency()));
81     }
82   }
83 
~TestAudioDeviceModuleImpl()84   ~TestAudioDeviceModuleImpl() override {
85     StopPlayout();
86     StopRecording();
87   }
88 
Init()89   int32_t Init() override {
90     task_queue_ =
91         std::make_unique<rtc::TaskQueue>(task_queue_factory_->CreateTaskQueue(
92             "TestAudioDeviceModuleImpl", TaskQueueFactory::Priority::NORMAL));
93 
94     RepeatingTaskHandle::Start(task_queue_->Get(), [this]() {
95       ProcessAudio();
96       return TimeDelta::Micros(process_interval_us_);
97     });
98     return 0;
99   }
100 
RegisterAudioCallback(AudioTransport * callback)101   int32_t RegisterAudioCallback(AudioTransport* callback) override {
102     MutexLock lock(&lock_);
103     RTC_DCHECK(callback || audio_callback_);
104     audio_callback_ = callback;
105     return 0;
106   }
107 
StartPlayout()108   int32_t StartPlayout() override {
109     MutexLock lock(&lock_);
110     RTC_CHECK(renderer_);
111     rendering_ = true;
112     return 0;
113   }
114 
StopPlayout()115   int32_t StopPlayout() override {
116     MutexLock lock(&lock_);
117     rendering_ = false;
118     return 0;
119   }
120 
StartRecording()121   int32_t StartRecording() override {
122     MutexLock lock(&lock_);
123     RTC_CHECK(capturer_);
124     capturing_ = true;
125     return 0;
126   }
127 
StopRecording()128   int32_t StopRecording() override {
129     MutexLock lock(&lock_);
130     capturing_ = false;
131     return 0;
132   }
133 
Playing() const134   bool Playing() const override {
135     MutexLock lock(&lock_);
136     return rendering_;
137   }
138 
Recording() const139   bool Recording() const override {
140     MutexLock lock(&lock_);
141     return capturing_;
142   }
143 
144   // Blocks until the Renderer refuses to receive data.
145   // Returns false if |timeout_ms| passes before that happens.
WaitForPlayoutEnd(int timeout_ms=rtc::Event::kForever)146   bool WaitForPlayoutEnd(int timeout_ms = rtc::Event::kForever) override {
147     return done_rendering_.Wait(timeout_ms);
148   }
149 
150   // Blocks until the Recorder stops producing data.
151   // Returns false if |timeout_ms| passes before that happens.
WaitForRecordingEnd(int timeout_ms=rtc::Event::kForever)152   bool WaitForRecordingEnd(int timeout_ms = rtc::Event::kForever) override {
153     return done_capturing_.Wait(timeout_ms);
154   }
155 
156  private:
ProcessAudio()157   void ProcessAudio() {
158     MutexLock lock(&lock_);
159     if (capturing_) {
160       // Capture 10ms of audio. 2 bytes per sample.
161       const bool keep_capturing = capturer_->Capture(&recording_buffer_);
162       uint32_t new_mic_level = 0;
163       if (recording_buffer_.size() > 0) {
164         audio_callback_->RecordedDataIsAvailable(
165             recording_buffer_.data(),
166             recording_buffer_.size() / capturer_->NumChannels(),
167             2 * capturer_->NumChannels(), capturer_->NumChannels(),
168             capturer_->SamplingFrequency(), 0, 0, 0, false, new_mic_level);
169       }
170       if (!keep_capturing) {
171         capturing_ = false;
172         done_capturing_.Set();
173       }
174     }
175     if (rendering_) {
176       size_t samples_out = 0;
177       int64_t elapsed_time_ms = -1;
178       int64_t ntp_time_ms = -1;
179       const int sampling_frequency = renderer_->SamplingFrequency();
180       audio_callback_->NeedMorePlayData(
181           SamplesPerFrame(sampling_frequency), 2 * renderer_->NumChannels(),
182           renderer_->NumChannels(), sampling_frequency, playout_buffer_.data(),
183           samples_out, &elapsed_time_ms, &ntp_time_ms);
184       const bool keep_rendering = renderer_->Render(
185           rtc::ArrayView<const int16_t>(playout_buffer_.data(), samples_out));
186       if (!keep_rendering) {
187         rendering_ = false;
188         done_rendering_.Set();
189       }
190     }
191   }
192   TaskQueueFactory* const task_queue_factory_;
193   const std::unique_ptr<Capturer> capturer_ RTC_GUARDED_BY(lock_);
194   const std::unique_ptr<Renderer> renderer_ RTC_GUARDED_BY(lock_);
195   const int64_t process_interval_us_;
196 
197   mutable Mutex lock_;
198   AudioTransport* audio_callback_ RTC_GUARDED_BY(lock_);
199   bool rendering_ RTC_GUARDED_BY(lock_);
200   bool capturing_ RTC_GUARDED_BY(lock_);
201   rtc::Event done_rendering_;
202   rtc::Event done_capturing_;
203 
204   std::vector<int16_t> playout_buffer_ RTC_GUARDED_BY(lock_);
205   rtc::BufferT<int16_t> recording_buffer_ RTC_GUARDED_BY(lock_);
206   std::unique_ptr<rtc::TaskQueue> task_queue_;
207 };
208 
209 // A fake capturer that generates pulses with random samples between
210 // -max_amplitude and +max_amplitude.
211 class PulsedNoiseCapturerImpl final
212     : public TestAudioDeviceModule::PulsedNoiseCapturer {
213  public:
214   // Assuming 10ms audio packets.
PulsedNoiseCapturerImpl(int16_t max_amplitude,int sampling_frequency_in_hz,int num_channels)215   PulsedNoiseCapturerImpl(int16_t max_amplitude,
216                           int sampling_frequency_in_hz,
217                           int num_channels)
218       : sampling_frequency_in_hz_(sampling_frequency_in_hz),
219         fill_with_zero_(false),
220         random_generator_(1),
221         max_amplitude_(max_amplitude),
222         num_channels_(num_channels) {
223     RTC_DCHECK_GT(max_amplitude, 0);
224   }
225 
SamplingFrequency() const226   int SamplingFrequency() const override { return sampling_frequency_in_hz_; }
227 
NumChannels() const228   int NumChannels() const override { return num_channels_; }
229 
Capture(rtc::BufferT<int16_t> * buffer)230   bool Capture(rtc::BufferT<int16_t>* buffer) override {
231     fill_with_zero_ = !fill_with_zero_;
232     int16_t max_amplitude;
233     {
234       MutexLock lock(&lock_);
235       max_amplitude = max_amplitude_;
236     }
237     buffer->SetData(
238         TestAudioDeviceModule::SamplesPerFrame(sampling_frequency_in_hz_) *
239             num_channels_,
240         [&](rtc::ArrayView<int16_t> data) {
241           if (fill_with_zero_) {
242             std::fill(data.begin(), data.end(), 0);
243           } else {
244             std::generate(data.begin(), data.end(), [&]() {
245               return random_generator_.Rand(-max_amplitude, max_amplitude);
246             });
247           }
248           return data.size();
249         });
250     return true;
251   }
252 
SetMaxAmplitude(int16_t amplitude)253   void SetMaxAmplitude(int16_t amplitude) override {
254     MutexLock lock(&lock_);
255     max_amplitude_ = amplitude;
256   }
257 
258  private:
259   int sampling_frequency_in_hz_;
260   bool fill_with_zero_;
261   Random random_generator_;
262   Mutex lock_;
263   int16_t max_amplitude_ RTC_GUARDED_BY(lock_);
264   const int num_channels_;
265 };
266 
267 class WavFileReader final : public TestAudioDeviceModule::Capturer {
268  public:
WavFileReader(std::string filename,int sampling_frequency_in_hz,int num_channels,bool repeat)269   WavFileReader(std::string filename,
270                 int sampling_frequency_in_hz,
271                 int num_channels,
272                 bool repeat)
273       : WavFileReader(std::make_unique<WavReader>(filename),
274                       sampling_frequency_in_hz,
275                       num_channels,
276                       repeat) {}
277 
SamplingFrequency() const278   int SamplingFrequency() const override { return sampling_frequency_in_hz_; }
279 
NumChannels() const280   int NumChannels() const override { return num_channels_; }
281 
Capture(rtc::BufferT<int16_t> * buffer)282   bool Capture(rtc::BufferT<int16_t>* buffer) override {
283     buffer->SetData(
284         TestAudioDeviceModule::SamplesPerFrame(sampling_frequency_in_hz_) *
285             num_channels_,
286         [&](rtc::ArrayView<int16_t> data) {
287           size_t read = wav_reader_->ReadSamples(data.size(), data.data());
288           if (read < data.size() && repeat_) {
289             do {
290               wav_reader_->Reset();
291               size_t delta = wav_reader_->ReadSamples(
292                   data.size() - read, data.subview(read).data());
293               RTC_CHECK_GT(delta, 0) << "No new data read from file";
294               read += delta;
295             } while (read < data.size());
296           }
297           return read;
298         });
299     return buffer->size() > 0;
300   }
301 
302  private:
WavFileReader(std::unique_ptr<WavReader> wav_reader,int sampling_frequency_in_hz,int num_channels,bool repeat)303   WavFileReader(std::unique_ptr<WavReader> wav_reader,
304                 int sampling_frequency_in_hz,
305                 int num_channels,
306                 bool repeat)
307       : sampling_frequency_in_hz_(sampling_frequency_in_hz),
308         num_channels_(num_channels),
309         wav_reader_(std::move(wav_reader)),
310         repeat_(repeat) {
311     RTC_CHECK_EQ(wav_reader_->sample_rate(), sampling_frequency_in_hz);
312     RTC_CHECK_EQ(wav_reader_->num_channels(), num_channels);
313   }
314 
315   const int sampling_frequency_in_hz_;
316   const int num_channels_;
317   std::unique_ptr<WavReader> wav_reader_;
318   const bool repeat_;
319 };
320 
321 class WavFileWriter final : public TestAudioDeviceModule::Renderer {
322  public:
WavFileWriter(std::string filename,int sampling_frequency_in_hz,int num_channels)323   WavFileWriter(std::string filename,
324                 int sampling_frequency_in_hz,
325                 int num_channels)
326       : WavFileWriter(std::make_unique<WavWriter>(filename,
327                                                   sampling_frequency_in_hz,
328                                                   num_channels),
329                       sampling_frequency_in_hz,
330                       num_channels) {}
331 
SamplingFrequency() const332   int SamplingFrequency() const override { return sampling_frequency_in_hz_; }
333 
NumChannels() const334   int NumChannels() const override { return num_channels_; }
335 
Render(rtc::ArrayView<const int16_t> data)336   bool Render(rtc::ArrayView<const int16_t> data) override {
337     wav_writer_->WriteSamples(data.data(), data.size());
338     return true;
339   }
340 
341  private:
WavFileWriter(std::unique_ptr<WavWriter> wav_writer,int sampling_frequency_in_hz,int num_channels)342   WavFileWriter(std::unique_ptr<WavWriter> wav_writer,
343                 int sampling_frequency_in_hz,
344                 int num_channels)
345       : sampling_frequency_in_hz_(sampling_frequency_in_hz),
346         wav_writer_(std::move(wav_writer)),
347         num_channels_(num_channels) {}
348 
349   int sampling_frequency_in_hz_;
350   std::unique_ptr<WavWriter> wav_writer_;
351   const int num_channels_;
352 };
353 
354 class BoundedWavFileWriter : public TestAudioDeviceModule::Renderer {
355  public:
BoundedWavFileWriter(std::string filename,int sampling_frequency_in_hz,int num_channels)356   BoundedWavFileWriter(std::string filename,
357                        int sampling_frequency_in_hz,
358                        int num_channels)
359       : sampling_frequency_in_hz_(sampling_frequency_in_hz),
360         wav_writer_(filename, sampling_frequency_in_hz, num_channels),
361         num_channels_(num_channels),
362         silent_audio_(
363             TestAudioDeviceModule::SamplesPerFrame(sampling_frequency_in_hz) *
364                 num_channels,
365             0),
366         started_writing_(false),
367         trailing_zeros_(0) {}
368 
SamplingFrequency() const369   int SamplingFrequency() const override { return sampling_frequency_in_hz_; }
370 
NumChannels() const371   int NumChannels() const override { return num_channels_; }
372 
Render(rtc::ArrayView<const int16_t> data)373   bool Render(rtc::ArrayView<const int16_t> data) override {
374     const int16_t kAmplitudeThreshold = 5;
375 
376     const int16_t* begin = data.begin();
377     const int16_t* end = data.end();
378     if (!started_writing_) {
379       // Cut off silence at the beginning.
380       while (begin < end) {
381         if (std::abs(*begin) > kAmplitudeThreshold) {
382           started_writing_ = true;
383           break;
384         }
385         ++begin;
386       }
387     }
388     if (started_writing_) {
389       // Cut off silence at the end.
390       while (begin < end) {
391         if (*(end - 1) != 0) {
392           break;
393         }
394         --end;
395       }
396       if (begin < end) {
397         // If it turns out that the silence was not final, need to write all the
398         // skipped zeros and continue writing audio.
399         while (trailing_zeros_ > 0) {
400           const size_t zeros_to_write =
401               std::min(trailing_zeros_, silent_audio_.size());
402           wav_writer_.WriteSamples(silent_audio_.data(), zeros_to_write);
403           trailing_zeros_ -= zeros_to_write;
404         }
405         wav_writer_.WriteSamples(begin, end - begin);
406       }
407       // Save the number of zeros we skipped in case this needs to be restored.
408       trailing_zeros_ += data.end() - end;
409     }
410     return true;
411   }
412 
413  private:
414   int sampling_frequency_in_hz_;
415   WavWriter wav_writer_;
416   const int num_channels_;
417   std::vector<int16_t> silent_audio_;
418   bool started_writing_;
419   size_t trailing_zeros_;
420 };
421 
422 class DiscardRenderer final : public TestAudioDeviceModule::Renderer {
423  public:
DiscardRenderer(int sampling_frequency_in_hz,int num_channels)424   explicit DiscardRenderer(int sampling_frequency_in_hz, int num_channels)
425       : sampling_frequency_in_hz_(sampling_frequency_in_hz),
426         num_channels_(num_channels) {}
427 
SamplingFrequency() const428   int SamplingFrequency() const override { return sampling_frequency_in_hz_; }
429 
NumChannels() const430   int NumChannels() const override { return num_channels_; }
431 
Render(rtc::ArrayView<const int16_t> data)432   bool Render(rtc::ArrayView<const int16_t> data) override { return true; }
433 
434  private:
435   int sampling_frequency_in_hz_;
436   const int num_channels_;
437 };
438 
439 }  // namespace
440 
SamplesPerFrame(int sampling_frequency_in_hz)441 size_t TestAudioDeviceModule::SamplesPerFrame(int sampling_frequency_in_hz) {
442   return rtc::CheckedDivExact(sampling_frequency_in_hz, kFramesPerSecond);
443 }
444 
Create(TaskQueueFactory * task_queue_factory,std::unique_ptr<TestAudioDeviceModule::Capturer> capturer,std::unique_ptr<TestAudioDeviceModule::Renderer> renderer,float speed)445 rtc::scoped_refptr<TestAudioDeviceModule> TestAudioDeviceModule::Create(
446     TaskQueueFactory* task_queue_factory,
447     std::unique_ptr<TestAudioDeviceModule::Capturer> capturer,
448     std::unique_ptr<TestAudioDeviceModule::Renderer> renderer,
449     float speed) {
450   return new rtc::RefCountedObject<TestAudioDeviceModuleImpl>(
451       task_queue_factory, std::move(capturer), std::move(renderer), speed);
452 }
453 
454 std::unique_ptr<TestAudioDeviceModule::PulsedNoiseCapturer>
CreatePulsedNoiseCapturer(int16_t max_amplitude,int sampling_frequency_in_hz,int num_channels)455 TestAudioDeviceModule::CreatePulsedNoiseCapturer(int16_t max_amplitude,
456                                                  int sampling_frequency_in_hz,
457                                                  int num_channels) {
458   return std::make_unique<PulsedNoiseCapturerImpl>(
459       max_amplitude, sampling_frequency_in_hz, num_channels);
460 }
461 
462 std::unique_ptr<TestAudioDeviceModule::Renderer>
CreateDiscardRenderer(int sampling_frequency_in_hz,int num_channels)463 TestAudioDeviceModule::CreateDiscardRenderer(int sampling_frequency_in_hz,
464                                              int num_channels) {
465   return std::make_unique<DiscardRenderer>(sampling_frequency_in_hz,
466                                            num_channels);
467 }
468 
469 std::unique_ptr<TestAudioDeviceModule::Capturer>
CreateWavFileReader(std::string filename,int sampling_frequency_in_hz,int num_channels)470 TestAudioDeviceModule::CreateWavFileReader(std::string filename,
471                                            int sampling_frequency_in_hz,
472                                            int num_channels) {
473   return std::make_unique<WavFileReader>(filename, sampling_frequency_in_hz,
474                                          num_channels, false);
475 }
476 
477 std::unique_ptr<TestAudioDeviceModule::Capturer>
CreateWavFileReader(std::string filename,bool repeat)478 TestAudioDeviceModule::CreateWavFileReader(std::string filename, bool repeat) {
479   WavReader reader(filename);
480   int sampling_frequency_in_hz = reader.sample_rate();
481   int num_channels = rtc::checked_cast<int>(reader.num_channels());
482   return std::make_unique<WavFileReader>(filename, sampling_frequency_in_hz,
483                                          num_channels, repeat);
484 }
485 
486 std::unique_ptr<TestAudioDeviceModule::Renderer>
CreateWavFileWriter(std::string filename,int sampling_frequency_in_hz,int num_channels)487 TestAudioDeviceModule::CreateWavFileWriter(std::string filename,
488                                            int sampling_frequency_in_hz,
489                                            int num_channels) {
490   return std::make_unique<WavFileWriter>(filename, sampling_frequency_in_hz,
491                                          num_channels);
492 }
493 
494 std::unique_ptr<TestAudioDeviceModule::Renderer>
CreateBoundedWavFileWriter(std::string filename,int sampling_frequency_in_hz,int num_channels)495 TestAudioDeviceModule::CreateBoundedWavFileWriter(std::string filename,
496                                                   int sampling_frequency_in_hz,
497                                                   int num_channels) {
498   return std::make_unique<BoundedWavFileWriter>(
499       filename, sampling_frequency_in_hz, num_channels);
500 }
501 
502 }  // namespace webrtc
503