diff --git a/supervisor/supervisor.py b/supervisor/supervisor.py index 082962930ce..fccb9ac3537 100644 --- a/supervisor/supervisor.py +++ b/supervisor/supervisor.py @@ -41,7 +41,7 @@ def _check_connectivity_throttle_period(coresys: CoreSys, *_) -> timedelta: if coresys.supervisor.connectivity: return timedelta(minutes=10) - return timedelta() + return timedelta(seconds=30) class Supervisor(CoreSysAttributes): diff --git a/tests/common.py b/tests/common.py index ad489a566a2..d7b507e9757 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,12 +1,15 @@ """Common test functions.""" +from datetime import datetime from importlib import import_module +from inspect import getclosurevars import json from pathlib import Path from typing import Any from dbus_fast.aio.message_bus import MessageBus +from supervisor.jobs.decorator import Job from supervisor.resolution.validate import get_valid_modules from supervisor.utils.yaml import read_yaml_file @@ -82,3 +85,17 @@ async def mock_dbus_services( services[module] = service_module.setup(to_mock[module]).export(bus) return services + + +def get_job_decorator(func) -> Job: + """Get Job object of decorated function.""" + # Access the closure of the wrapper function + job = getclosurevars(func).nonlocals["self"] + if not isinstance(job, Job): + raise TypeError(f"{func.__qualname__} is not a Job") + return job + + +def reset_last_call(func, group: str | None = None) -> None: + """Reset last call for a function using the Job decorator.""" + get_job_decorator(func).set_last_call(datetime.min, group) diff --git a/tests/jobs/test_job_decorator.py b/tests/jobs/test_job_decorator.py index 426973d7c75..8d618d94805 100644 --- a/tests/jobs/test_job_decorator.py +++ b/tests/jobs/test_job_decorator.py @@ -26,8 +26,11 @@ from supervisor.os.manager import OSManager from supervisor.plugins.audio import PluginAudio from supervisor.resolution.const import UnhealthyReason +from supervisor.supervisor import Supervisor from supervisor.utils.dt import utcnow +from tests.common import reset_last_call + async def test_healthy(coresys: CoreSys, caplog: pytest.LogCaptureFixture): """Test the healty decorator.""" @@ -73,6 +76,7 @@ async def test_internet( ): """Test the internet decorator.""" coresys.core.state = CoreState.RUNNING + reset_last_call(Supervisor.check_connectivity) class TestClass: """Test class.""" diff --git a/tests/test_supervisor.py b/tests/test_supervisor.py index 0254de9c6f8..42ec8d5baa7 100644 --- a/tests/test_supervisor.py +++ b/tests/test_supervisor.py @@ -1,6 +1,6 @@ """Test supervisor object.""" -from datetime import datetime +from datetime import datetime, timedelta import errno from unittest.mock import AsyncMock, Mock, PropertyMock, patch @@ -8,6 +8,7 @@ from aiohttp.client_exceptions import ClientError from awesomeversion import AwesomeVersion import pytest +from time_machine import travel from supervisor.const import UpdateChannel from supervisor.coresys import CoreSys @@ -22,6 +23,8 @@ from supervisor.resolution.data import Issue from supervisor.supervisor import Supervisor +from tests.common import reset_last_call + @pytest.fixture(name="websession", scope="function") async def fixture_webession(coresys: CoreSys) -> AsyncMock: @@ -58,21 +61,33 @@ async def test_connectivity_check( assert supervisor_unthrottled.connectivity is connectivity -@pytest.mark.parametrize("side_effect,call_count", [(ClientError(), 3), (None, 1)]) +@pytest.mark.parametrize( + "side_effect,call_interval,throttled", + [ + (None, timedelta(minutes=5), True), + (None, timedelta(minutes=15), False), + (ClientError(), timedelta(seconds=20), True), + (ClientError(), timedelta(seconds=40), False), + ], +) async def test_connectivity_check_throttling( coresys: CoreSys, websession: AsyncMock, side_effect: Exception | None, - call_count: int, + call_interval: timedelta, + throttled: bool, ): """Test connectivity check throttled when checks succeed.""" coresys.supervisor.connectivity = None websession.head.side_effect = side_effect - for _ in range(3): + reset_last_call(Supervisor.check_connectivity) + with travel(datetime.now(), tick=False) as traveller: + await coresys.supervisor.check_connectivity() + traveller.shift(call_interval) await coresys.supervisor.check_connectivity() - assert websession.head.call_count == call_count + assert websession.head.call_count == (1 if throttled else 2) async def test_update_failed(coresys: CoreSys, capture_exception: Mock):