1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
12 #define MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
13 
14 #include <array>
15 #include <memory>
16 #include <vector>
17 
18 #include "api/optional.h"
19 #include "modules/audio_processing/aec3/aec3_common.h"
20 #include "modules/audio_processing/aec3/downsampled_render_buffer.h"
21 #include "rtc_base/constructormagic.h"
22 
23 namespace webrtc {
24 namespace aec3 {
25 
26 #if defined(WEBRTC_HAS_NEON)
27 
28 // Filter core for the matched filter that is optimized for NEON.
29 void MatchedFilterCore_NEON(size_t x_start_index,
30                             float x2_sum_threshold,
31                             rtc::ArrayView<const float> x,
32                             rtc::ArrayView<const float> y,
33                             rtc::ArrayView<float> h,
34                             bool* filters_updated,
35                             float* error_sum);
36 
37 #endif
38 
39 #if defined(WEBRTC_ARCH_X86_FAMILY)
40 
41 // Filter core for the matched filter that is optimized for SSE2.
42 void MatchedFilterCore_SSE2(size_t x_start_index,
43                             float x2_sum_threshold,
44                             rtc::ArrayView<const float> x,
45                             rtc::ArrayView<const float> y,
46                             rtc::ArrayView<float> h,
47                             bool* filters_updated,
48                             float* error_sum);
49 
50 #endif
51 
52 // Filter core for the matched filter.
53 void MatchedFilterCore(size_t x_start_index,
54                        float x2_sum_threshold,
55                        rtc::ArrayView<const float> x,
56                        rtc::ArrayView<const float> y,
57                        rtc::ArrayView<float> h,
58                        bool* filters_updated,
59                        float* error_sum);
60 
61 }  // namespace aec3
62 
63 class ApmDataDumper;
64 
65 // Produces recursively updated cross-correlation estimates for several signal
66 // shifts where the intra-shift spacing is uniform.
67 class MatchedFilter {
68  public:
69   // Stores properties for the lag estimate corresponding to a particular signal
70   // shift.
71   struct LagEstimate {
72     LagEstimate() = default;
LagEstimateLagEstimate73     LagEstimate(float accuracy, bool reliable, size_t lag, bool updated)
74         : accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {}
75 
76     float accuracy = 0.f;
77     bool reliable = false;
78     size_t lag = 0;
79     bool updated = false;
80   };
81 
82   MatchedFilter(ApmDataDumper* data_dumper,
83                 Aec3Optimization optimization,
84                 size_t sub_block_size,
85                 size_t window_size_sub_blocks,
86                 int num_matched_filters,
87                 size_t alignment_shift_sub_blocks,
88                 float excitation_limit);
89 
90   ~MatchedFilter();
91 
92   // Updates the correlation with the values in the capture buffer.
93   void Update(const DownsampledRenderBuffer& render_buffer,
94               rtc::ArrayView<const float> capture);
95 
96   // Resets the matched filter.
97   void Reset();
98 
99   // Returns the current lag estimates.
GetLagEstimates()100   rtc::ArrayView<const MatchedFilter::LagEstimate> GetLagEstimates() const {
101     return lag_estimates_;
102   }
103 
104   // Returns the maximum filter lag.
GetMaxFilterLag()105   size_t GetMaxFilterLag() const {
106     return filters_.size() * filter_intra_lag_shift_ + filters_[0].size();
107   }
108 
109   // Log matched filter properties.
110   void LogFilterProperties(int sample_rate_hz,
111                            size_t shift,
112                            size_t downsampling_factor) const;
113 
114  private:
115   ApmDataDumper* const data_dumper_;
116   const Aec3Optimization optimization_;
117   const size_t sub_block_size_;
118   const size_t filter_intra_lag_shift_;
119   std::vector<std::vector<float>> filters_;
120   std::vector<LagEstimate> lag_estimates_;
121   std::vector<size_t> filters_offsets_;
122   const float excitation_limit_;
123 
124   RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(MatchedFilter);
125 };
126 
127 }  // namespace webrtc
128 
129 #endif  // MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_
130