From ee59303d3c60eabe8aba13239af4050353f1d193 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 6 Sep 2024 10:57:09 -0500 Subject: [PATCH] Use first media player announcement format for TTS (#125237) * Use ANNOUNCEMENT format from first media player for tts * Fix formatting --------- Co-authored-by: Paulus Schoutsen --- .../components/assist_satellite/entity.py | 11 ++- .../components/esphome/assist_satellite.py | 27 +++++++ .../components/esphome/entry_data.py | 4 + .../components/esphome/media_player.py | 10 ++- .../assist_satellite/test_entity.py | 2 +- .../esphome/test_assist_satellite.py | 73 ++++++++++++++++++- 6 files changed, 123 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index 6ec40ae24f7c58..38973f15f55765 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -72,6 +72,7 @@ class AssistSatelliteEntity(entity.Entity): _run_has_tts: bool = False _is_announcing = False _wake_word_intercept_future: asyncio.Future[str | None] | None = None + _attr_tts_options: dict[str, Any] | None = None __assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD @@ -91,6 +92,11 @@ def vad_sensitivity_entity_id(self) -> str | None: """Entity ID of the VAD sensitivity to use for the next conversation.""" return self._attr_vad_sensitivity_entity_id + @property + def tts_options(self) -> dict[str, Any] | None: + """Options passed for text-to-speech.""" + return self._attr_tts_options + async def async_intercept_wake_word(self) -> str | None: """Intercept the next wake word from the satellite. @@ -137,6 +143,9 @@ async def async_internal_announce( if pipeline.tts_voice is not None: tts_options[tts.ATTR_VOICE] = pipeline.tts_voice + if self.tts_options is not None: + tts_options.update(self.tts_options) + media_id = tts_generate_media_source_id( self.hass, message, @@ -253,7 +262,7 @@ async def async_accept_pipeline_from_satellite( pipeline_id=self._resolve_pipeline(), conversation_id=self._conversation_id, device_id=device_id, - tts_audio_output="wav", + tts_audio_output=self.tts_options, wake_word_phrase=wake_word_phrase, audio_settings=AudioSettings( silence_seconds=self._resolve_vad_sensitivity() diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py index 48bb9ec55070ef..f84940eadc4585 100644 --- a/homeassistant/components/esphome/assist_satellite.py +++ b/homeassistant/components/esphome/assist_satellite.py @@ -6,12 +6,14 @@ from collections.abc import AsyncIterable from functools import partial import io +from itertools import chain import logging import socket from typing import Any, cast import wave from aioesphomeapi import ( + MediaPlayerFormatPurpose, VoiceAssistantAudioSettings, VoiceAssistantCommandFlag, VoiceAssistantEventType, @@ -288,6 +290,18 @@ async def handle_pipeline_start( end_stage = PipelineStage.TTS + if feature_flags & VoiceAssistantFeature.SPEAKER: + # Stream WAV audio + self._attr_tts_options = { + tts.ATTR_PREFERRED_FORMAT: "wav", + tts.ATTR_PREFERRED_SAMPLE_RATE: 16000, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + } + else: + # ANNOUNCEMENT format from media player + self._update_tts_format() + # Run the pipeline _LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage) self.entry_data.async_set_assist_pipeline_state(True) @@ -340,6 +354,19 @@ def handle_timer_event( timer_info.is_active, ) + def _update_tts_format(self) -> None: + """Update the TTS format from the first media player.""" + for supported_format in chain(*self.entry_data.media_player_formats.values()): + # Find first announcement format + if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT: + self._attr_tts_options = { + tts.ATTR_PREFERRED_FORMAT: supported_format.format, + tts.ATTR_PREFERRED_SAMPLE_RATE: supported_format.sample_rate, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: supported_format.num_channels, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + } + break + async def _stream_tts_audio( self, media_id: str, diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 6fc40612c489f0..f1b5218eec7083 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -31,6 +31,7 @@ LightInfo, LockInfo, MediaPlayerInfo, + MediaPlayerSupportedFormat, NumberInfo, SelectInfo, SensorInfo, @@ -148,6 +149,9 @@ class RuntimeEntryData: tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]] ] = field(default_factory=dict) original_options: dict[str, Any] = field(default_factory=dict) + media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field( + default_factory=lambda: defaultdict(list) + ) @property def name(self) -> str: diff --git a/homeassistant/components/esphome/media_player.py b/homeassistant/components/esphome/media_player.py index f7c5d7011f87bb..4d57552bb19ea7 100644 --- a/homeassistant/components/esphome/media_player.py +++ b/homeassistant/components/esphome/media_player.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import partial -from typing import Any +from typing import Any, cast from aioesphomeapi import ( EntityInfo, @@ -66,6 +66,9 @@ def _on_static_info_update(self, static_info: EntityInfo) -> None: if self._static_info.supports_pause: flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY self._attr_supported_features = flags + self._entry_data.media_player_formats[self.entity_id] = cast( + MediaPlayerInfo, static_info + ).supported_formats @property @esphome_state_property @@ -103,6 +106,11 @@ async def async_play_media( self._key, media_url=media_id, announcement=announcement ) + async def async_will_remove_from_hass(self) -> None: + """Handle entity being removed.""" + await super().async_will_remove_from_hass() + self._entry_data.media_player_formats.pop(self.entity_id, None) + async def async_browse_media( self, media_content_type: MediaType | str | None = None, diff --git a/tests/components/assist_satellite/test_entity.py b/tests/components/assist_satellite/test_entity.py index 2e4caca030bc9f..ec52d8abff4383 100644 --- a/tests/components/assist_satellite/test_entity.py +++ b/tests/components/assist_satellite/test_entity.py @@ -61,7 +61,7 @@ async def test_entity_state( assert kwargs["stt_stream"] is audio_stream assert kwargs["pipeline_id"] is None assert kwargs["device_id"] is None - assert kwargs["tts_audio_output"] == "wav" + assert kwargs["tts_audio_output"] is None assert kwargs["wake_word_phrase"] is None assert kwargs["audio_settings"] == AudioSettings( silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT) diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py index f024ca3b078342..1c7f7320a85a80 100644 --- a/tests/components/esphome/test_assist_satellite.py +++ b/tests/components/esphome/test_assist_satellite.py @@ -11,6 +11,9 @@ APIClient, EntityInfo, EntityState, + MediaPlayerFormatPurpose, + MediaPlayerInfo, + MediaPlayerSupportedFormat, UserService, VoiceAssistantAudioSettings, VoiceAssistantCommandFlag, @@ -20,7 +23,7 @@ ) import pytest -from homeassistant.components import assist_satellite +from homeassistant.components import assist_satellite, tts from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType from homeassistant.components.assist_satellite.entity import ( AssistSatelliteEntity, @@ -820,3 +823,71 @@ async def get_slow_wav( VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}, ) + + +async def test_tts_format_from_media_player( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test that the text-to-speech format is pulled from the first media player.""" + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[ + MediaPlayerInfo( + object_id="mymedia_player", + key=1, + name="my media_player", + unique_id="my_media_player", + supports_pause=True, + supported_formats=[ + MediaPlayerSupportedFormat( + format="flac", + sample_rate=48000, + num_channels=2, + purpose=MediaPlayerFormatPurpose.DEFAULT, + ), + # This is the format that should be used for tts + MediaPlayerSupportedFormat( + format="mp3", + sample_rate=22050, + num_channels=1, + purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT, + ), + ], + ) + ], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + ) as mock_pipeline_from_audio_stream: + await satellite.handle_pipeline_start( + conversation_id="", + flags=0, + audio_settings=VoiceAssistantAudioSettings(), + wake_word_phrase=None, + ) + + mock_pipeline_from_audio_stream.assert_called_once() + kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs + + # Should be ANNOUNCEMENT format from media player + assert kwargs.get("tts_audio_output") == { + tts.ATTR_PREFERRED_FORMAT: "mp3", + tts.ATTR_PREFERRED_SAMPLE_RATE: 22050, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + }