Use first media player announcement format for TTS (#125237)
* Use ANNOUNCEMENT format from the first media player for TTS

* Fix formatting

---------

Co-authored-by: Paulus Schoutsen <[email protected]>
synesthesiam and balloob authored Sep 6, 2024
1 parent 20639b0 commit ee59303
Showing 6 changed files with 123 additions and 4 deletions.
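
At a glance, the commit lets each Assist satellite advertise a preferred text-to-speech output format, which is merged into announcement TTS options and forwarded to the pipeline as tts_audio_output. The sketch below shows the shape of that options dict; the key names and the WAV values are taken from this diff (the ESPHome speaker branch), and the concrete values are only an example.

from homeassistant.components import tts

# Preferred-format options a satellite can expose via AssistSatelliteEntity.tts_options.
# Key names come from this diff; the values below match the ESPHome speaker branch
# and are illustrative only.
preferred_tts_options = {
    tts.ATTR_PREFERRED_FORMAT: "wav",       # audio container/codec
    tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,  # Hz
    tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,  # mono
    tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,     # 16-bit samples
}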
11 changes: 10 additions & 1 deletion homeassistant/components/assist_satellite/entity.py
@@ -72,6 +72,7 @@ class AssistSatelliteEntity(entity.Entity):
_run_has_tts: bool = False
_is_announcing = False
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None

__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD

@@ -91,6 +92,11 @@ def vad_sensitivity_entity_id(self) -> str | None:
"""Entity ID of the VAD sensitivity to use for the next conversation."""
return self._attr_vad_sensitivity_entity_id

@property
def tts_options(self) -> dict[str, Any] | None:
"""Options passed for text-to-speech."""
return self._attr_tts_options

async def async_intercept_wake_word(self) -> str | None:
"""Intercept the next wake word from the satellite.
@@ -137,6 +143,9 @@ async def async_internal_announce(
if pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice

if self.tts_options is not None:
tts_options.update(self.tts_options)

media_id = tts_generate_media_source_id(
self.hass,
message,
@@ -253,7 +262,7 @@ async def async_accept_pipeline_from_satellite(
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output="wav",
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
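
The base entity builds the announcement TTS options from the pipeline's engine and voice, then layers the satellite's preferences on top with dict.update, and hands the same dict to the pipeline as tts_audio_output. A minimal sketch of that merge order, using a hypothetical helper (combined_announce_options is not part of the commit):

from homeassistant.components import tts


def combined_announce_options(
    pipeline_tts_voice: str | None,
    satellite_tts_options: dict | None,
) -> dict:
    """Hypothetical helper mirroring the merge in async_internal_announce."""
    options: dict = {}
    if pipeline_tts_voice is not None:
        options[tts.ATTR_VOICE] = pipeline_tts_voice
    # The satellite's preferences extend (and, on key collisions, override)
    # the pipeline-derived defaults.
    if satellite_tts_options is not None:
        options.update(satellite_tts_options)
    return options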
27 changes: 27 additions & 0 deletions homeassistant/components/esphome/assist_satellite.py
@@ -6,12 +6,14 @@
from collections.abc import AsyncIterable
from functools import partial
import io
from itertools import chain
import logging
import socket
from typing import Any, cast
import wave

from aioesphomeapi import (
MediaPlayerFormatPurpose,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
@@ -288,6 +290,18 @@ async def handle_pipeline_start(

end_stage = PipelineStage.TTS

if feature_flags & VoiceAssistantFeature.SPEAKER:
# Stream WAV audio
self._attr_tts_options = {
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
else:
# ANNOUNCEMENT format from media player
self._update_tts_format()

# Run the pipeline
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
self.entry_data.async_set_assist_pipeline_state(True)
@@ -340,6 +354,19 @@ def handle_timer_event(
timer_info.is_active,
)

def _update_tts_format(self) -> None:
"""Update the TTS format from the first media player."""
for supported_format in chain(*self.entry_data.media_player_formats.values()):
# Find first announcement format
if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
self._attr_tts_options = {
tts.ATTR_PREFERRED_FORMAT: supported_format.format,
tts.ATTR_PREFERRED_SAMPLE_RATE: supported_format.sample_rate,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: supported_format.num_channels,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
break

async def _stream_tts_audio(
self,
media_id: str,
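
_update_tts_format walks every registered media player's supported formats and picks the first one whose purpose is ANNOUNCEMENT. A self-contained sketch of that selection, assuming stand-in classes in place of aioesphomeapi's MediaPlayerSupportedFormat and MediaPlayerFormatPurpose:

from dataclasses import dataclass
from enum import Enum
from itertools import chain


class FormatPurpose(Enum):
    """Stand-in for aioesphomeapi.MediaPlayerFormatPurpose."""

    DEFAULT = 0
    ANNOUNCEMENT = 1


@dataclass
class SupportedFormat:
    """Stand-in for aioesphomeapi.MediaPlayerSupportedFormat."""

    format: str
    sample_rate: int
    num_channels: int
    purpose: FormatPurpose


def first_announcement_format(
    formats_by_player: dict[str, list[SupportedFormat]],
) -> SupportedFormat | None:
    """Return the first ANNOUNCEMENT format across all media players, if any."""
    for fmt in chain(*formats_by_player.values()):
        if fmt.purpose is FormatPurpose.ANNOUNCEMENT:
            return fmt
    return None


formats = {
    "media_player.front": [
        SupportedFormat("flac", 48000, 2, FormatPurpose.DEFAULT),
        SupportedFormat("mp3", 22050, 1, FormatPurpose.ANNOUNCEMENT),
    ]
}
selected = first_announcement_format(formats)
assert selected is not None and selected.format == "mp3"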
4 changes: 4 additions & 0 deletions homeassistant/components/esphome/entry_data.py
@@ -31,6 +31,7 @@
LightInfo,
LockInfo,
MediaPlayerInfo,
MediaPlayerSupportedFormat,
NumberInfo,
SelectInfo,
SensorInfo,
@@ -148,6 +149,9 @@ class RuntimeEntryData:
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
] = field(default_factory=dict)
original_options: dict[str, Any] = field(default_factory=dict)
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
default_factory=lambda: defaultdict(list)
)

@property
def name(self) -> str:
10 changes: 9 additions & 1 deletion homeassistant/components/esphome/media_player.py
@@ -3,7 +3,7 @@
from __future__ import annotations

from functools import partial
from typing import Any
from typing import Any, cast

from aioesphomeapi import (
EntityInfo,
@@ -66,6 +66,9 @@ def _on_static_info_update(self, static_info: EntityInfo) -> None:
if self._static_info.supports_pause:
flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY
self._attr_supported_features = flags
self._entry_data.media_player_formats[self.entity_id] = cast(
MediaPlayerInfo, static_info
).supported_formats

@property
@esphome_state_property
@@ -103,6 +106,11 @@ async def async_play_media(
self._key, media_url=media_id, announcement=announcement
)

async def async_will_remove_from_hass(self) -> None:
"""Handle entity being removed."""
await super().async_will_remove_from_hass()
self._entry_data.media_player_formats.pop(self.entity_id, None)

async def async_browse_media(
self,
media_content_type: MediaType | str | None = None,
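
The media player registers its supported formats under its entity_id when static info arrives and drops them when the entity is removed from Home Assistant, so the satellite never picks a stale format. A rough sketch of that bookkeeping, assuming hypothetical register/unregister helpers around the defaultdict(list) field added in entry_data.py:

from collections import defaultdict

# Stand-in for RuntimeEntryData.media_player_formats; real entries hold
# MediaPlayerSupportedFormat objects rather than plain strings.
media_player_formats: dict[str, list[str]] = defaultdict(list)


def register_formats(entity_id: str, formats: list[str]) -> None:
    """Hypothetical helper: overwrite this player's formats on static info update."""
    media_player_formats[entity_id] = formats


def unregister_formats(entity_id: str) -> None:
    """Hypothetical helper: drop the entry when the entity is removed."""
    media_player_formats.pop(entity_id, None)


register_formats("media_player.kitchen", ["flac", "mp3"])
unregister_formats("media_player.kitchen")
# Reading a missing key yields an empty list instead of raising, thanks to defaultdict.
assert media_player_formats["media_player.bedroom"] == []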
2 changes: 1 addition & 1 deletion tests/components/assist_satellite/test_entity.py
@@ -61,7 +61,7 @@ async def test_entity_state(
assert kwargs["stt_stream"] is audio_stream
assert kwargs["pipeline_id"] is None
assert kwargs["device_id"] is None
assert kwargs["tts_audio_output"] == "wav"
assert kwargs["tts_audio_output"] is None
assert kwargs["wake_word_phrase"] is None
assert kwargs["audio_settings"] == AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
73 changes: 72 additions & 1 deletion tests/components/esphome/test_assist_satellite.py
@@ -11,6 +11,9 @@
APIClient,
EntityInfo,
EntityState,
MediaPlayerFormatPurpose,
MediaPlayerInfo,
MediaPlayerSupportedFormat,
UserService,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
@@ -20,7 +23,7 @@
)
import pytest

from homeassistant.components import assist_satellite
from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite.entity import (
AssistSatelliteEntity,
@@ -820,3 +823,71 @@ async def get_slow_wav(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
{},
)


async def test_tts_format_from_media_player(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that the text-to-speech format is pulled from the first media player."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[
MediaPlayerInfo(
object_id="mymedia_player",
key=1,
name="my media_player",
unique_id="my_media_player",
supports_pause=True,
supported_formats=[
MediaPlayerSupportedFormat(
format="flac",
sample_rate=48000,
num_channels=2,
purpose=MediaPlayerFormatPurpose.DEFAULT,
),
# This is the format that should be used for tts
MediaPlayerSupportedFormat(
format="mp3",
sample_rate=22050,
num_channels=1,
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
),
],
)
],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
},
)
await hass.async_block_till_done()

satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None

with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_pipeline_from_audio_stream:
await satellite.handle_pipeline_start(
conversation_id="",
flags=0,
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase=None,
)

mock_pipeline_from_audio_stream.assert_called_once()
kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs

# Should be ANNOUNCEMENT format from media player
assert kwargs.get("tts_audio_output") == {
tts.ATTR_PREFERRED_FORMAT: "mp3",
tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
