diff --git a/main/Kconfig.projbuild b/main/Kconfig.projbuild index 47b23553..79995d3e 100644 --- a/main/Kconfig.projbuild +++ b/main/Kconfig.projbuild @@ -680,6 +680,16 @@ config SEND_WAKE_WORD_DATA help Send wake word data to the server as the first message of the conversation and wait for response +config WAKE_WORD_DETECTION_IN_LISTENING + bool "Enable Wake Word Detection in Listening Mode" + default n + depends on USE_AFE_WAKE_WORD || USE_CUSTOM_WAKE_WORD + help + Enable wake word detection while in listening mode. + When enabled, the device can detect wake word during listening, + which allows interrupting the current conversation. + When disabled (default), wake word detection is turned off during listening. + config USE_AUDIO_PROCESSOR bool "Enable Audio Noise Reduction" default y diff --git a/main/application.cc b/main/application.cc index 89d5eb62..9d544c70 100644 --- a/main/application.cc +++ b/main/application.cc @@ -691,7 +691,7 @@ void Application::HandleToggleChatEvent() { } if (state == kDeviceStateIdle) { - ListeningMode mode = aec_mode_ == kAecOff ? kListeningModeAutoStop : kListeningModeRealtime; + ListeningMode mode = GetDefaultListeningMode(); if (!protocol_->IsAudioChannelOpened()) { SetDeviceState(kDeviceStateConnecting); // Schedule to let the state change be processed first (UI update) @@ -777,7 +777,9 @@ void Application::HandleWakeWordDetectedEvent() { } auto state = GetDeviceState(); - + auto wake_word = audio_service_.GetLastWakeWord(); + ESP_LOGI(TAG, "Wake word detected: %s (state: %d)", wake_word.c_str(), (int)state); + if (state == kDeviceStateIdle) { audio_service_.EncodeWakeWord(); auto wake_word = audio_service_.GetLastWakeWord(); @@ -793,8 +795,22 @@ void Application::HandleWakeWordDetectedEvent() { } // Channel already opened, continue directly ContinueWakeWordInvoke(wake_word); - } else if (state == kDeviceStateSpeaking) { + } else if (state == kDeviceStateSpeaking || state == kDeviceStateListening) { AbortSpeaking(kAbortReasonWakeWordDetected); + // Clear send queue to avoid sending residues to server + while (audio_service_.PopPacketFromSendQueue()); + + if (state == kDeviceStateListening) { + protocol_->SendStartListening(GetDefaultListeningMode()); + audio_service_.ResetDecoder(); + audio_service_.PlaySound(Lang::Sounds::OGG_POPUP); + // Re-enable wake word detection as it was stopped by the detection itself + audio_service_.EnableWakeWordDetection(true); + } else { + // Play popup sound and start listening again + play_popup_on_listening_ = true; + SetListeningMode(GetDefaultListeningMode()); + } } else if (state == kDeviceStateActivating) { // Restart the activation check if the wake word is detected during activation SetDeviceState(kDeviceStateIdle); @@ -822,12 +838,15 @@ void Application::ContinueWakeWordInvoke(const std::string& wake_word) { } // Set the chat state to wake word detected protocol_->SendWakeWordDetected(wake_word); - SetListeningMode(aec_mode_ == kAecOff ? kListeningModeAutoStop : kListeningModeRealtime); + + // Set flag to play popup sound after state changes to listening + play_popup_on_listening_ = true; + SetListeningMode(GetDefaultListeningMode()); #else // Set flag to play popup sound after state changes to listening // (PlaySound here would be cleared by ResetDecoder in EnableVoiceProcessing) play_popup_on_listening_ = true; - SetListeningMode(aec_mode_ == kAecOff ? kListeningModeAutoStop : kListeningModeRealtime); + SetListeningMode(GetDefaultListeningMode()); #endif } @@ -859,7 +878,7 @@ void Application::HandleStateChangedEvent() { display->SetEmotion("neutral"); // Make sure the audio processor is running - if (!audio_service_.IsAudioProcessorRunning()) { + if (play_popup_on_listening_ || !audio_service_.IsAudioProcessorRunning()) { // For auto mode, wait for playback queue to be empty before enabling voice processing // This prevents audio truncation when STOP arrives late due to network jitter if (listening_mode_ == kListeningModeAutoStop) { @@ -869,9 +888,16 @@ void Application::HandleStateChangedEvent() { // Send the start listening command protocol_->SendStartListening(listening_mode_); audio_service_.EnableVoiceProcessing(true); - audio_service_.EnableWakeWordDetection(false); } +#ifdef CONFIG_WAKE_WORD_DETECTION_IN_LISTENING + // Enable wake word detection in listening mode (configured via Kconfig) + audio_service_.EnableWakeWordDetection(audio_service_.IsAfeWakeWord()); +#else + // Disable wake word detection in listening mode + audio_service_.EnableWakeWordDetection(false); +#endif + // Play popup sound after ResetDecoder (in EnableVoiceProcessing) has been called if (play_popup_on_listening_) { play_popup_on_listening_ = false; @@ -919,6 +945,10 @@ void Application::SetListeningMode(ListeningMode mode) { SetDeviceState(kDeviceStateListening); } +ListeningMode Application::GetDefaultListeningMode() const { + return aec_mode_ == kAecOff ? kListeningModeAutoStop : kListeningModeRealtime; +} + void Application::Reboot() { ESP_LOGI(TAG, "Rebooting..."); // Disconnect the audio channel diff --git a/main/application.h b/main/application.h index bcb81112..7ca7af4f 100644 --- a/main/application.h +++ b/main/application.h @@ -165,6 +165,7 @@ private: void InitializeProtocol(); void ShowActivationCode(const std::string& code, const std::string& message); void SetListeningMode(ListeningMode mode); + ListeningMode GetDefaultListeningMode() const; // State change handler called by state machine void OnStateChanged(DeviceState old_state, DeviceState new_state); diff --git a/main/audio/audio_service.cc b/main/audio/audio_service.cc index 0cb1c345..fc816b68 100644 --- a/main/audio/audio_service.cc +++ b/main/audio/audio_service.cc @@ -265,27 +265,18 @@ void AudioService::AudioInputTask() { } } - /* Feed the wake word */ - if (bits & AS_EVENT_WAKE_WORD_RUNNING) { + /* Feed the wake word and/or audio processor */ + if (bits & (AS_EVENT_WAKE_WORD_RUNNING | AS_EVENT_AUDIO_PROCESSOR_RUNNING)) { + int samples = 160; // 10ms std::vector data; - int samples = wake_word_->GetFeedSize(); - if (samples > 0) { - if (ReadAudioData(data, 16000, samples)) { + if (ReadAudioData(data, 16000, samples)) { + if (bits & AS_EVENT_WAKE_WORD_RUNNING) { wake_word_->Feed(data); - continue; } - } - } - - /* Feed the audio processor */ - if (bits & AS_EVENT_AUDIO_PROCESSOR_RUNNING) { - std::vector data; - int samples = audio_processor_->GetFeedSize(); - if (samples > 0) { - if (ReadAudioData(data, 16000, samples)) { + if (bits & AS_EVENT_AUDIO_PROCESSOR_RUNNING) { audio_processor_->Feed(std::move(data)); - continue; } + continue; } } diff --git a/main/audio/processors/afe_audio_processor.cc b/main/audio/processors/afe_audio_processor.cc index 0281b135..15e9cc54 100644 --- a/main/audio/processors/afe_audio_processor.cc +++ b/main/audio/processors/afe_audio_processor.cc @@ -92,7 +92,18 @@ void AfeAudioProcessor::Feed(std::vector&& data) { if (afe_data_ == nullptr) { return; } - afe_iface_->feed(afe_data_, data.data()); + + std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() + if (!IsRunning()) { + return; + } + input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); + size_t chunk_size = afe_iface_->get_feed_chunksize(afe_data_) * codec_->input_channels(); + while (input_buffer_.size() >= chunk_size) { + afe_iface_->feed(afe_data_, input_buffer_.data()); + input_buffer_.erase(input_buffer_.begin(), input_buffer_.begin() + chunk_size); + } } void AfeAudioProcessor::Start() { @@ -101,9 +112,12 @@ void AfeAudioProcessor::Start() { void AfeAudioProcessor::Stop() { xEventGroupClearBits(event_group_, PROCESSOR_RUNNING); + + std::lock_guard lock(input_buffer_mutex_); if (afe_data_ != nullptr) { afe_iface_->reset_buffer(afe_data_); } + input_buffer_.clear(); } bool AfeAudioProcessor::IsRunning() { diff --git a/main/audio/processors/afe_audio_processor.h b/main/audio/processors/afe_audio_processor.h index 74d7fa40..fe2f0250 100644 --- a/main/audio/processors/afe_audio_processor.h +++ b/main/audio/processors/afe_audio_processor.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "audio_processor.h" #include "audio_codec.h" @@ -37,6 +38,8 @@ private: AudioCodec* codec_ = nullptr; int frame_samples_ = 0; bool is_speaking_ = false; + std::vector input_buffer_; + std::mutex input_buffer_mutex_; std::vector output_buffer_; void AudioProcessorTask(); diff --git a/main/audio/processors/no_audio_processor.h b/main/audio/processors/no_audio_processor.h index d326d505..13e4897d 100644 --- a/main/audio/processors/no_audio_processor.h +++ b/main/audio/processors/no_audio_processor.h @@ -3,6 +3,7 @@ #include #include +#include #include "audio_processor.h" #include "audio_codec.h" @@ -27,7 +28,7 @@ private: int frame_samples_ = 0; std::function&& data)> output_callback_; std::function vad_state_change_callback_; - bool is_running_ = false; + std::atomic is_running_ = false; }; #endif \ No newline at end of file diff --git a/main/audio/wake_words/afe_wake_word.cc b/main/audio/wake_words/afe_wake_word.cc index d597f8a5..8fc5fefb 100644 --- a/main/audio/wake_words/afe_wake_word.cc +++ b/main/audio/wake_words/afe_wake_word.cc @@ -99,19 +99,30 @@ void AfeWakeWord::Start() { void AfeWakeWord::Stop() { xEventGroupClearBits(event_group_, DETECTION_RUNNING_EVENT); + + std::lock_guard lock(input_buffer_mutex_); if (afe_data_ != nullptr) { afe_iface_->reset_buffer(afe_data_); } + input_buffer_.clear(); } void AfeWakeWord::Feed(const std::vector& data) { if (afe_data_ == nullptr) { return; } + + std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() if (!(xEventGroupGetBits(event_group_) & DETECTION_RUNNING_EVENT)) { return; } - afe_iface_->feed(afe_data_, data.data()); + input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); + size_t chunk_size = afe_iface_->get_feed_chunksize(afe_data_) * codec_->input_channels(); + while (input_buffer_.size() >= chunk_size) { + afe_iface_->feed(afe_data_, input_buffer_.data()); + input_buffer_.erase(input_buffer_.begin(), input_buffer_.begin() + chunk_size); + } } size_t AfeWakeWord::GetFeedSize() { diff --git a/main/audio/wake_words/afe_wake_word.h b/main/audio/wake_words/afe_wake_word.h index 8f5a2807..6c2bf72d 100644 --- a/main/audio/wake_words/afe_wake_word.h +++ b/main/audio/wake_words/afe_wake_word.h @@ -44,6 +44,8 @@ private: std::function wake_word_detected_callback_; AudioCodec* codec_ = nullptr; std::string last_detected_wake_word_; + std::vector input_buffer_; + std::mutex input_buffer_mutex_; TaskHandle_t wake_word_encode_task_ = nullptr; StaticTask_t* wake_word_encode_task_buffer_ = nullptr; diff --git a/main/audio/wake_words/custom_wake_word.cc b/main/audio/wake_words/custom_wake_word.cc index d677d284..982bbe96 100644 --- a/main/audio/wake_words/custom_wake_word.cc +++ b/main/audio/wake_words/custom_wake_word.cc @@ -138,49 +138,64 @@ void CustomWakeWord::Start() { void CustomWakeWord::Stop() { running_ = false; + + std::lock_guard lock(input_buffer_mutex_); + input_buffer_.clear(); } void CustomWakeWord::Feed(const std::vector& data) { - if (multinet_model_data_ == nullptr || !running_) { + if (multinet_model_data_ == nullptr) { + return; + } + + std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() + if (!running_) { return; } - esp_mn_state_t mn_state; // If input channels is 2, we need to fetch the left channel data if (codec_->input_channels() == 2) { - auto mono_data = std::vector(data.size() / 2); - for (size_t i = 0, j = 0; i < mono_data.size(); ++i, j += 2) { - mono_data[i] = data[j]; + for (size_t i = 0; i < data.size(); i += 2) { + input_buffer_.push_back(data[i]); } - - StoreWakeWordData(mono_data); - mn_state = multinet_->detect(multinet_model_data_, const_cast(mono_data.data())); } else { - StoreWakeWordData(data); - mn_state = multinet_->detect(multinet_model_data_, const_cast(data.data())); + input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); } - if (mn_state == ESP_MN_STATE_DETECTING) { - return; - } else if (mn_state == ESP_MN_STATE_DETECTED) { - esp_mn_results_t *mn_result = multinet_->get_results(multinet_model_data_); - for (int i = 0; i < mn_result->num && running_; i++) { - ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f", - mn_result->command_id[i], mn_result->string, mn_result->prob[i]); - auto& command = commands_[mn_result->command_id[i] - 1]; - if (command.action == "wake") { - last_detected_wake_word_ = command.text; - running_ = false; - - if (wake_word_detected_callback_) { - wake_word_detected_callback_(last_detected_wake_word_); + int chunksize = multinet_->get_samp_chunksize(multinet_model_data_); + while (input_buffer_.size() >= chunksize) { + std::vector chunk(input_buffer_.begin(), input_buffer_.begin() + chunksize); + StoreWakeWordData(chunk); + + esp_mn_state_t mn_state = multinet_->detect(multinet_model_data_, chunk.data()); + + if (mn_state == ESP_MN_STATE_DETECTED) { + esp_mn_results_t *mn_result = multinet_->get_results(multinet_model_data_); + for (int i = 0; i < mn_result->num && running_; i++) { + ESP_LOGI(TAG, "Custom wake word detected: command_id=%d, string=%s, prob=%f", + mn_result->command_id[i], mn_result->string, mn_result->prob[i]); + auto& command = commands_[mn_result->command_id[i] - 1]; + if (command.action == "wake") { + last_detected_wake_word_ = command.text; + running_ = false; + input_buffer_.clear(); + + if (wake_word_detected_callback_) { + wake_word_detected_callback_(last_detected_wake_word_); + } } } + multinet_->clean(multinet_model_data_); + } else if (mn_state == ESP_MN_STATE_TIMEOUT) { + ESP_LOGD(TAG, "Command word detection timeout, cleaning state"); + multinet_->clean(multinet_model_data_); } - multinet_->clean(multinet_model_data_); - } else if (mn_state == ESP_MN_STATE_TIMEOUT) { - ESP_LOGD(TAG, "Command word detection timeout, cleaning state"); - multinet_->clean(multinet_model_data_); + + if (!running_) { + break; + } + input_buffer_.erase(input_buffer_.begin(), input_buffer_.begin() + chunksize); } } diff --git a/main/audio/wake_words/custom_wake_word.h b/main/audio/wake_words/custom_wake_word.h index d4e6d8c3..645ad1b3 100644 --- a/main/audio/wake_words/custom_wake_word.h +++ b/main/audio/wake_words/custom_wake_word.h @@ -53,6 +53,8 @@ private: AudioCodec* codec_ = nullptr; std::string last_detected_wake_word_; std::atomic running_ = false; + std::vector input_buffer_; + std::mutex input_buffer_mutex_; TaskHandle_t wake_word_encode_task_ = nullptr; StaticTask_t* wake_word_encode_task_buffer_ = nullptr; diff --git a/main/audio/wake_words/esp_wake_word.cc b/main/audio/wake_words/esp_wake_word.cc index d4aaf9d0..930aa2d6 100644 --- a/main/audio/wake_words/esp_wake_word.cc +++ b/main/audio/wake_words/esp_wake_word.cc @@ -54,21 +54,44 @@ void EspWakeWord::Start() { void EspWakeWord::Stop() { running_ = false; + + std::lock_guard lock(input_buffer_mutex_); + input_buffer_.clear(); } void EspWakeWord::Feed(const std::vector& data) { - if (wakenet_data_ == nullptr || !running_) { + if (wakenet_data_ == nullptr) { return; } - int res = wakenet_iface_->detect(wakenet_data_, (int16_t *)data.data()); - if (res > 0) { - last_detected_wake_word_ = wakenet_iface_->get_word_name(wakenet_data_, res); - running_ = false; + std::lock_guard lock(input_buffer_mutex_); + // Check running state inside lock to avoid TOCTOU race with Stop() + if (!running_) { + return; + } - if (wake_word_detected_callback_) { - wake_word_detected_callback_(last_detected_wake_word_); + if (codec_->input_channels() == 2) { + for (size_t i = 0; i < data.size(); i += 2) { + input_buffer_.push_back(data[i]); } + } else { + input_buffer_.insert(input_buffer_.end(), data.begin(), data.end()); + } + + int chunksize = wakenet_iface_->get_samp_chunksize(wakenet_data_); + while (input_buffer_.size() >= chunksize) { + int res = wakenet_iface_->detect(wakenet_data_, input_buffer_.data()); + if (res > 0) { + last_detected_wake_word_ = wakenet_iface_->get_word_name(wakenet_data_, res); + running_ = false; + input_buffer_.clear(); + + if (wake_word_detected_callback_) { + wake_word_detected_callback_(last_detected_wake_word_); + } + break; + } + input_buffer_.erase(input_buffer_.begin(), input_buffer_.begin() + chunksize); } } diff --git a/main/audio/wake_words/esp_wake_word.h b/main/audio/wake_words/esp_wake_word.h index 9a1d73aa..87e792da 100644 --- a/main/audio/wake_words/esp_wake_word.h +++ b/main/audio/wake_words/esp_wake_word.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "audio_codec.h" #include "wake_word.h" @@ -37,6 +38,8 @@ private: std::function wake_word_detected_callback_; std::string last_detected_wake_word_; + std::vector input_buffer_; + std::mutex input_buffer_mutex_; }; #endif