Skip to content

Commit

Permalink
Subscribe per component for MQTT discovery (#119974)
Browse files Browse the repository at this point in the history
* Subscribe per component for MQTT discovery

* Use single assignment

* Handle wildcard subscriptions first

* Split subsRecription handling, update helper

* Fix help_all_subscribe_calls

* Fix import

* Fix test

* Update import order

* Undo move self._last_subscribe

* Recover removed test

* Revert not needed changes to binary_sensor platform tests

* Revert line removal

* Rework interation of discovery topics

* Reduce

* Add comment

* Move comment

* Chain subscriptions
  • Loading branch information
jbouwh authored Aug 20, 2024
1 parent a1e3e7f commit b74aced
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 26 deletions.
20 changes: 17 additions & 3 deletions homeassistant/components/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
TIMEOUT_ACK = 10
RECONNECT_INTERVAL_SECONDS = 10

MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
MAX_SUBSCRIBES_PER_CALL = 500
MAX_UNSUBSCRIBES_PER_CALL = 500

Expand Down Expand Up @@ -893,14 +894,27 @@ async def _async_perform_subscriptions(self) -> None:
if not self._pending_subscriptions:
return

subscriptions: dict[str, int] = self._pending_subscriptions
# Split out the wildcard subscriptions, we subscribe to them one by one
pending_subscriptions: dict[str, int] = self._pending_subscriptions
pending_wildcard_subscriptions = {
subscription.topic: pending_subscriptions.pop(subscription.topic)
for subscription in self._wildcard_subscriptions
if subscription.topic in pending_subscriptions
}

self._pending_subscriptions = {}

subscription_list = list(subscriptions.items())
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)

for chunk in chunked_or_all(subscription_list, MAX_SUBSCRIBES_PER_CALL):
for chunk in chain(
chunked_or_all(
pending_wildcard_subscriptions.items(), MAX_WILDCARD_SUBSCRIBES_PER_CALL
),
chunked_or_all(pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL),
):
chunk_list = list(chunk)
if not chunk_list:
continue

result, mid = self._mqttc.subscribe(chunk_list)

Expand Down
17 changes: 10 additions & 7 deletions homeassistant/components/mqtt/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
from collections import deque
import functools
from itertools import chain
import logging
import re
import time
Expand Down Expand Up @@ -238,10 +239,6 @@ def async_discovery_message_received(msg: ReceiveMessage) -> None: # noqa: C901

component, node_id, object_id = match.groups()

if component not in SUPPORTED_COMPONENTS:
_LOGGER.warning("Integration %s is not supported", component)
return

if payload:
try:
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
Expand Down Expand Up @@ -351,9 +348,15 @@ def discovery_done(_: Any) -> None:
0,
job_type=HassJobType.Callback,
)
for topic in (
f"{discovery_topic}/+/+/config",
f"{discovery_topic}/+/+/+/config",
for topic in chain(
(
f"{discovery_topic}/{component}/+/config"
for component in SUPPORTED_COMPONENTS
),
(
f"{discovery_topic}/{component}/+/+/config"
for component in SUPPORTED_COMPONENTS
),
)
]

Expand Down
6 changes: 4 additions & 2 deletions tests/components/mqtt/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from homeassistant.components import mqtt
from homeassistant.components.mqtt.client import RECONNECT_INTERVAL_SECONDS
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
from homeassistant.const import (
Expand Down Expand Up @@ -1614,8 +1615,9 @@ async def test_subscription_done_when_birth_message_is_sent(
"""Test sending birth message until initial subscription has been completed."""
mqtt_client_mock = setup_with_birth_msg_client_mock
subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
assert ("homeassistant/+/+/config", 0) in subscribe_calls
assert ("homeassistant/+/+/+/config", 0) in subscribe_calls
for component in SUPPORTED_COMPONENTS:
assert (f"homeassistant/{component}/+/config", 0) in subscribe_calls
assert (f"homeassistant/{component}/+/+/config", 0) in subscribe_calls
mqtt_client_mock.publish.assert_called_with(
"homeassistant/status", "online", 0, False
)
Expand Down
19 changes: 14 additions & 5 deletions tests/components/mqtt/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from homeassistant import config as module_hass_config
from homeassistant.components import mqtt
from homeassistant.components.mqtt import debug_info
from homeassistant.components.mqtt.const import MQTT_CONNECTION_STATE
from homeassistant.components.mqtt.const import (
MQTT_CONNECTION_STATE,
SUPPORTED_COMPONENTS,
)
from homeassistant.components.mqtt.mixins import MQTT_ATTRIBUTES_BLOCKED
from homeassistant.components.mqtt.models import PublishPayloadType
from homeassistant.config_entries import ConfigEntryState
Expand Down Expand Up @@ -75,9 +78,12 @@
def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]:
"""Test of a call."""
all_calls = []
for calls in mqtt_client_mock.subscribe.mock_calls:
for call in calls[1]:
all_calls.extend(call)
for call_l1 in mqtt_client_mock.subscribe.mock_calls:
if isinstance(call_l1[1][0], list):
for call_l2 in call_l1[1]:
all_calls.extend(call_l2)
else:
all_calls.append(call_l1[1])
return all_calls


Expand Down Expand Up @@ -1178,7 +1184,10 @@ async def help_test_entity_id_update_subscriptions(

state = hass.states.get(f"{domain}.test")
assert state is not None
assert mqtt_mock.async_subscribe.call_count == len(topics) + 2 + DISCOVERY_COUNT
assert (
mqtt_mock.async_subscribe.call_count
== len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT
)
for topic in topics:
mqtt_mock.async_subscribe.assert_any_call(
topic, ANY, ANY, ANY, HassJobType.Callback
Expand Down
14 changes: 5 additions & 9 deletions tests/components/mqtt/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ABBREVIATIONS,
DEVICE_ABBREVIATIONS,
)
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
from homeassistant.components.mqtt.discovery import (
MQTT_DISCOVERY_DONE,
MQTT_DISCOVERY_NEW,
Expand Down Expand Up @@ -73,13 +74,10 @@ async def test_subscribing_config_topic(
discovery_topic = "homeassistant"
await async_start(hass, discovery_topic, entry)

call_args1 = mqtt_mock.async_subscribe.mock_calls[0][1]
assert call_args1[2] == 0
call_args2 = mqtt_mock.async_subscribe.mock_calls[1][1]
assert call_args2[2] == 0
topics = [call_args1[0], call_args2[0]]
assert discovery_topic + "/+/+/config" in topics
assert discovery_topic + "/+/+/+/config" in topics
topics = [call[1][0] for call in mqtt_mock.async_subscribe.mock_calls]
for component in SUPPORTED_COMPONENTS:
assert f"{discovery_topic}/{component}/+/config" in topics
assert f"{discovery_topic}/{component}/+/+/config" in topics


@pytest.mark.parametrize(
Expand Down Expand Up @@ -198,8 +196,6 @@ async def test_only_valid_components(

await hass.async_block_till_done()

assert f"Integration {invalid_component} is not supported" in caplog.text

assert not mock_dispatcher_send.called


Expand Down

0 comments on commit b74aced

Please sign in to comment.