1 /*
2  *  Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "modules/audio_processing/aec3/cascaded_biquad_filter.h"
11 
12 #include <algorithm>
13 
14 #include "rtc_base/checks.h"
15 
16 namespace webrtc {
17 
BiQuadParam(std::complex<float> zero,std::complex<float> pole,float gain,bool mirror_zero_along_i_axis)18 CascadedBiQuadFilter::BiQuadParam::BiQuadParam(std::complex<float> zero,
19                                                std::complex<float> pole,
20                                                float gain,
21                                                bool mirror_zero_along_i_axis)
22     : zero(zero),
23       pole(pole),
24       gain(gain),
25       mirror_zero_along_i_axis(mirror_zero_along_i_axis) {}
26 
27 CascadedBiQuadFilter::BiQuadParam::BiQuadParam(const BiQuadParam&) = default;
28 
// Builds one second-order section from a zero/pole description, with the
// delay-line state x/y value-initialized.
//
// The pole is always paired with its complex conjugate, which gives the real
// denominator coefficients a[0] = -2 Re(p) and a[1] = |p|^2.
//
// The zero q is paired in one of two ways:
//  - with its conjugate, giving numerator coefficients (1, -2 Re(q), |q|^2);
//  - with its negation when mirror_zero_along_i_axis is set, giving
//    (1, 0, -Re(q)^2). This is only valid for a purely real zero.
// In both cases the numerator coefficients are scaled by the gain.
CascadedBiQuadFilter::BiQuad::BiQuad(
    const CascadedBiQuadFilter::BiQuadParam& param)
    : x(), y() {
  // Denominator: conjugate pole pair (p_r + p_i*i) and (p_r - p_i*i).
  const float p_r = param.pole.real();
  const float p_i = param.pole.imag();
  coefficients.a[0] = -2.f * p_r;
  coefficients.a[1] = p_r * p_r + p_i * p_i;

  // Numerator: either zeros at z_r and -z_r (mirrored) or at the conjugate
  // pair (z_r + z_i*i) and (z_r - z_i*i).
  const float z_r = param.zero.real();
  const float z_i = param.zero.imag();
  const float gain = param.gain;
  const bool mirrored = param.mirror_zero_along_i_axis;
  if (mirrored) {
    // Mirroring only makes sense for a zero on the real axis.
    RTC_DCHECK(z_i == 0.f);
  }
  coefficients.b[0] = gain * 1.f;
  coefficients.b[1] = mirrored ? 0.f : gain * -2.f * z_r;
  coefficients.b[2] =
      mirrored ? gain * -(z_r * z_r) : gain * (z_r * z_r + z_i * z_i);
}
55 
// Builds a cascade of `num_biquads` identical sections. Each section shares
// the given coefficients and starts from its own copy of the initial state.
CascadedBiQuadFilter::CascadedBiQuadFilter(
    const CascadedBiQuadFilter::BiQuadCoefficients& coefficients,
    size_t num_biquads)
    : biquads_(num_biquads, BiQuad(coefficients)) {}
60 
CascadedBiQuadFilter(const std::vector<CascadedBiQuadFilter::BiQuadParam> & biquad_params)61 CascadedBiQuadFilter::CascadedBiQuadFilter(
62     const std::vector<CascadedBiQuadFilter::BiQuadParam>& biquad_params) {
63   for (const auto& param : biquad_params) {
64     biquads_.push_back(BiQuad(param));
65   }
66 }
67 
68 CascadedBiQuadFilter::~CascadedBiQuadFilter() = default;
69 
Process(rtc::ArrayView<const float> x,rtc::ArrayView<float> y)70 void CascadedBiQuadFilter::Process(rtc::ArrayView<const float> x,
71                                    rtc::ArrayView<float> y) {
72   if (biquads_.size() > 0) {
73     ApplyBiQuad(x, y, &biquads_[0]);
74     for (size_t k = 1; k < biquads_.size(); ++k) {
75       ApplyBiQuad(y, y, &biquads_[k]);
76     }
77   } else {
78     std::copy(x.begin(), x.end(), y.begin());
79   }
80 }
81 
Process(rtc::ArrayView<float> y)82 void CascadedBiQuadFilter::Process(rtc::ArrayView<float> y) {
83   for (auto& biquad : biquads_) {
84     ApplyBiQuad(y, y, &biquad);
85   }
86 }
87 
ApplyBiQuad(rtc::ArrayView<const float> x,rtc::ArrayView<float> y,CascadedBiQuadFilter::BiQuad * biquad)88 void CascadedBiQuadFilter::ApplyBiQuad(rtc::ArrayView<const float> x,
89                                        rtc::ArrayView<float> y,
90                                        CascadedBiQuadFilter::BiQuad* biquad) {
91   RTC_DCHECK_EQ(x.size(), y.size());
92   const auto* c_b = biquad->coefficients.b;
93   const auto* c_a = biquad->coefficients.a;
94   auto* m_x = biquad->x;
95   auto* m_y = biquad->y;
96   for (size_t k = 0; k < x.size(); ++k) {
97     const float tmp = x[k];
98     y[k] = c_b[0] * tmp + c_b[1] * m_x[0] + c_b[2] * m_x[1] - c_a[0] * m_y[0] -
99            c_a[1] * m_y[1];
100     m_x[1] = m_x[0];
101     m_x[0] = tmp;
102     m_y[1] = m_y[0];
103     m_y[0] = y[k];
104   }
105 }
106 
107 }  // namespace webrtc
108