diff --git a/main/audio/processors/afe_audio_processor.cc b/main/audio/processors/afe_audio_processor.cc index 0bbacbb8..15e9cc54 100644 --- a/main/audio/processors/afe_audio_processor.cc +++ b/main/audio/processors/afe_audio_processor.cc @@ -89,11 +89,15 @@ size_t AfeAudioProcessor::GetFeedSize() { } void AfeAudioProcessor::Feed(std::vector&& data) { - if (afe_data_ == nullptr || !IsRunning()) { + if (afe_data_ == nullptr) { return; } std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() + if (!IsRunning()) { + return; + } input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); size_t chunk_size = afe_iface_->get_feed_chunksize(afe_data_) * codec_->input_channels(); while (input_buffer_.size() >= chunk_size) { @@ -108,11 +112,11 @@ void AfeAudioProcessor::Start() { void AfeAudioProcessor::Stop() { xEventGroupClearBits(event_group_, PROCESSOR_RUNNING); + + std::lock_guard lock(input_buffer_mutex_); if (afe_data_ != nullptr) { afe_iface_->reset_buffer(afe_data_); } - - std::lock_guard lock(input_buffer_mutex_); input_buffer_.clear(); } diff --git a/main/audio/processors/no_audio_processor.h b/main/audio/processors/no_audio_processor.h index d326d505..13e4897d 100644 --- a/main/audio/processors/no_audio_processor.h +++ b/main/audio/processors/no_audio_processor.h @@ -3,6 +3,7 @@ #include #include +#include #include "audio_processor.h" #include "audio_codec.h" @@ -27,7 +28,7 @@ private: int frame_samples_ = 0; std::function&& data)> output_callback_; std::function vad_state_change_callback_; - bool is_running_ = false; + std::atomic is_running_ = false; }; #endif \ No newline at end of file diff --git a/main/audio/wake_words/afe_wake_word.cc b/main/audio/wake_words/afe_wake_word.cc index 1b09ca39..8fc5fefb 100644 --- a/main/audio/wake_words/afe_wake_word.cc +++ b/main/audio/wake_words/afe_wake_word.cc @@ -99,20 +99,24 @@ void AfeWakeWord::Start() { void AfeWakeWord::Stop() { xEventGroupClearBits(event_group_, DETECTION_RUNNING_EVENT); + + std::lock_guard lock(input_buffer_mutex_); if (afe_data_ != nullptr) { afe_iface_->reset_buffer(afe_data_); } - - std::lock_guard lock(input_buffer_mutex_); input_buffer_.clear(); } void AfeWakeWord::Feed(const std::vector& data) { - if (afe_data_ == nullptr || !(xEventGroupGetBits(event_group_) & DETECTION_RUNNING_EVENT)) { + if (afe_data_ == nullptr) { return; } std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() + if (!(xEventGroupGetBits(event_group_) & DETECTION_RUNNING_EVENT)) { + return; + } input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); size_t chunk_size = afe_iface_->get_feed_chunksize(afe_data_) * codec_->input_channels(); while (input_buffer_.size() >= chunk_size) { diff --git a/main/audio/wake_words/custom_wake_word.cc b/main/audio/wake_words/custom_wake_word.cc index 70095742..982bbe96 100644 --- a/main/audio/wake_words/custom_wake_word.cc +++ b/main/audio/wake_words/custom_wake_word.cc @@ -138,11 +138,19 @@ void CustomWakeWord::Start() { void CustomWakeWord::Stop() { running_ = false; + + std::lock_guard lock(input_buffer_mutex_); input_buffer_.clear(); } void CustomWakeWord::Feed(const std::vector& data) { - if (multinet_model_data_ == nullptr || !running_) { + if (multinet_model_data_ == nullptr) { + return; + } + + std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() + if (!running_) { return; } diff --git a/main/audio/wake_words/custom_wake_word.h b/main/audio/wake_words/custom_wake_word.h index 03043a5a..645ad1b3 100644 --- a/main/audio/wake_words/custom_wake_word.h +++ b/main/audio/wake_words/custom_wake_word.h @@ -54,6 +54,7 @@ private: std::string last_detected_wake_word_; std::atomic running_ = false; std::vector input_buffer_; + std::mutex input_buffer_mutex_; TaskHandle_t wake_word_encode_task_ = nullptr; StaticTask_t* wake_word_encode_task_buffer_ = nullptr; diff --git a/main/audio/wake_words/esp_wake_word.cc b/main/audio/wake_words/esp_wake_word.cc index c40d8f0a..930aa2d6 100644 --- a/main/audio/wake_words/esp_wake_word.cc +++ b/main/audio/wake_words/esp_wake_word.cc @@ -54,11 +54,19 @@ void EspWakeWord::Start() { void EspWakeWord::Stop() { running_ = false; + + std::lock_guard lock(input_buffer_mutex_); input_buffer_.clear(); } void EspWakeWord::Feed(const std::vector& data) { - if (wakenet_data_ == nullptr || !running_) { + if (wakenet_data_ == nullptr) { + return; + } + + std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() + if (!running_) { return; } diff --git a/main/audio/wake_words/esp_wake_word.h b/main/audio/wake_words/esp_wake_word.h index 95f96ad8..87e792da 100644 --- a/main/audio/wake_words/esp_wake_word.h +++ b/main/audio/wake_words/esp_wake_word.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "audio_codec.h" #include "wake_word.h" @@ -38,6 +39,7 @@ private: std::function wake_word_detected_callback_; std::string last_detected_wake_word_; std::vector input_buffer_; + std::mutex input_buffer_mutex_; }; #endif