diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 6762f440c5a058..4fa8b7db02a8b1 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -111,6 +111,7 @@ TIMEOUT_ACK = 10 RECONNECT_INTERVAL_SECONDS = 10 +MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1 MAX_SUBSCRIBES_PER_CALL = 500 MAX_UNSUBSCRIBES_PER_CALL = 500 @@ -893,14 +894,27 @@ async def _async_perform_subscriptions(self) -> None: if not self._pending_subscriptions: return - subscriptions: dict[str, int] = self._pending_subscriptions + # Split out the wildcard subscriptions, we subscribe to them one by one + pending_subscriptions: dict[str, int] = self._pending_subscriptions + pending_wildcard_subscriptions = { + subscription.topic: pending_subscriptions.pop(subscription.topic) + for subscription in self._wildcard_subscriptions + if subscription.topic in pending_subscriptions + } + self._pending_subscriptions = {} - subscription_list = list(subscriptions.items()) debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) - for chunk in chunked_or_all(subscription_list, MAX_SUBSCRIBES_PER_CALL): + for chunk in chain( + chunked_or_all( + pending_wildcard_subscriptions.items(), MAX_WILDCARD_SUBSCRIBES_PER_CALL + ), + chunked_or_all(pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL), + ): chunk_list = list(chunk) + if not chunk_list: + continue result, mid = self._mqttc.subscribe(chunk_list) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index cf2941a36656a9..8e379633674ce3 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -5,6 +5,7 @@ import asyncio from collections import deque import functools +from itertools import chain import logging import re import time @@ -238,10 +239,6 @@ def async_discovery_message_received(msg: ReceiveMessage) -> None: # noqa: C901 component, node_id, object_id = match.groups() - if component not in SUPPORTED_COMPONENTS: - _LOGGER.warning("Integration %s is not supported", component) - return - if payload: try: discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload)) @@ -351,9 +348,15 @@ def discovery_done(_: Any) -> None: 0, job_type=HassJobType.Callback, ) - for topic in ( - f"{discovery_topic}/+/+/config", - f"{discovery_topic}/+/+/+/config", + for topic in chain( + ( + f"{discovery_topic}/{component}/+/config" + for component in SUPPORTED_COMPONENTS + ), + ( + f"{discovery_topic}/{component}/+/+/config" + for component in SUPPORTED_COMPONENTS + ), ) ] diff --git a/tests/components/mqtt/test_client.py b/tests/components/mqtt/test_client.py index c5887016f2e9ae..dcded7d187a9aa 100644 --- a/tests/components/mqtt/test_client.py +++ b/tests/components/mqtt/test_client.py @@ -13,6 +13,7 @@ from homeassistant.components import mqtt from homeassistant.components.mqtt.client import RECONNECT_INTERVAL_SECONDS +from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState from homeassistant.const import ( @@ -1614,8 +1615,9 @@ async def test_subscription_done_when_birth_message_is_sent( """Test sending birth message until initial subscription has been completed.""" mqtt_client_mock = setup_with_birth_msg_client_mock subscribe_calls = help_all_subscribe_calls(mqtt_client_mock) - assert ("homeassistant/+/+/config", 0) in subscribe_calls - assert ("homeassistant/+/+/+/config", 0) in subscribe_calls + for component in SUPPORTED_COMPONENTS: + assert (f"homeassistant/{component}/+/config", 0) in subscribe_calls + assert (f"homeassistant/{component}/+/+/config", 0) in subscribe_calls mqtt_client_mock.publish.assert_called_with( "homeassistant/status", "online", 0, False ) diff --git a/tests/components/mqtt/test_common.py b/tests/components/mqtt/test_common.py index f7ebd039d1a127..c135c29ebc5313 100644 --- a/tests/components/mqtt/test_common.py +++ b/tests/components/mqtt/test_common.py @@ -16,7 +16,10 @@ from homeassistant import config as module_hass_config from homeassistant.components import mqtt from homeassistant.components.mqtt import debug_info -from homeassistant.components.mqtt.const import MQTT_CONNECTION_STATE +from homeassistant.components.mqtt.const import ( + MQTT_CONNECTION_STATE, + SUPPORTED_COMPONENTS, +) from homeassistant.components.mqtt.mixins import MQTT_ATTRIBUTES_BLOCKED from homeassistant.components.mqtt.models import PublishPayloadType from homeassistant.config_entries import ConfigEntryState @@ -75,9 +78,12 @@ def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]: """Test of a call.""" all_calls = [] - for calls in mqtt_client_mock.subscribe.mock_calls: - for call in calls[1]: - all_calls.extend(call) + for call_l1 in mqtt_client_mock.subscribe.mock_calls: + if isinstance(call_l1[1][0], list): + for call_l2 in call_l1[1]: + all_calls.extend(call_l2) + else: + all_calls.append(call_l1[1]) return all_calls @@ -1178,7 +1184,10 @@ async def help_test_entity_id_update_subscriptions( state = hass.states.get(f"{domain}.test") assert state is not None - assert mqtt_mock.async_subscribe.call_count == len(topics) + 2 + DISCOVERY_COUNT + assert ( + mqtt_mock.async_subscribe.call_count + == len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT + ) for topic in topics: mqtt_mock.async_subscribe.assert_any_call( topic, ANY, ANY, ANY, HassJobType.Callback diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 58de3c53c52a45..7f58fc75daecef 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -15,6 +15,7 @@ ABBREVIATIONS, DEVICE_ABBREVIATIONS, ) +from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS from homeassistant.components.mqtt.discovery import ( MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_NEW, @@ -73,13 +74,10 @@ async def test_subscribing_config_topic( discovery_topic = "homeassistant" await async_start(hass, discovery_topic, entry) - call_args1 = mqtt_mock.async_subscribe.mock_calls[0][1] - assert call_args1[2] == 0 - call_args2 = mqtt_mock.async_subscribe.mock_calls[1][1] - assert call_args2[2] == 0 - topics = [call_args1[0], call_args2[0]] - assert discovery_topic + "/+/+/config" in topics - assert discovery_topic + "/+/+/+/config" in topics + topics = [call[1][0] for call in mqtt_mock.async_subscribe.mock_calls] + for component in SUPPORTED_COMPONENTS: + assert f"{discovery_topic}/{component}/+/config" in topics + assert f"{discovery_topic}/{component}/+/+/config" in topics @pytest.mark.parametrize( @@ -198,8 +196,6 @@ async def test_only_valid_components( await hass.async_block_till_done() - assert f"Integration {invalid_component} is not supported" in caplog.text - assert not mock_dispatcher_send.called