1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_ 13 14 #include <array> 15 #include <vector> 16 17 #include "absl/strings/string_view.h" 18 #include "api/array_view.h" 19 #include "api/function_view.h" 20 #include "modules/audio_processing/agc2/cpu_features.h" 21 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h" 22 23 namespace webrtc { 24 namespace rnn_vad { 25 26 // Activation function for a neural network cell. 27 enum class ActivationFunction { kTansigApproximated, kSigmoidApproximated }; 28 29 // Maximum number of units for an FC layer. 30 constexpr int kFullyConnectedLayerMaxUnits = 24; 31 32 // Fully-connected layer with a custom activation function which owns the output 33 // buffer. 34 class FullyConnectedLayer { 35 public: 36 // Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`. 37 FullyConnectedLayer(int input_size, 38 int output_size, 39 rtc::ArrayView<const int8_t> bias, 40 rtc::ArrayView<const int8_t> weights, 41 ActivationFunction activation_function, 42 const AvailableCpuFeatures& cpu_features, 43 absl::string_view layer_name); 44 FullyConnectedLayer(const FullyConnectedLayer&) = delete; 45 FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; 46 ~FullyConnectedLayer(); 47 48 // Returns the size of the input vector. input_size()49 int input_size() const { return input_size_; } 50 // Returns the pointer to the first element of the output buffer. data()51 const float* data() const { return output_.data(); } 52 // Returns the size of the output buffer. size()53 int size() const { return output_size_; } 54 55 // Computes the fully-connected layer output. 56 void ComputeOutput(rtc::ArrayView<const float> input); 57 58 private: 59 const int input_size_; 60 const int output_size_; 61 const std::vector<float> bias_; 62 const std::vector<float> weights_; 63 const VectorMath vector_math_; 64 rtc::FunctionView<float(float)> activation_function_; 65 // Over-allocated array with size equal to `output_size_`. 66 std::array<float, kFullyConnectedLayerMaxUnits> output_; 67 }; 68 69 } // namespace rnn_vad 70 } // namespace webrtc 71 72 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_ 73