From 266a3d085e6de898bc846fc26b2b0076fe42e50f Mon Sep 17 00:00:00 2001 From: Dos Moonen Date: Sat, 4 May 2024 23:35:05 +0200 Subject: [PATCH] Rework discovery timeout logic (#153) --- setup.py | 1 - solax/__init__.py | 17 ++-- solax/discovery.py | 164 ++++++++++++++++------------------ solax/inverter_http_client.py | 10 ++- 4 files changed, 98 insertions(+), 94 deletions(-) diff --git a/setup.py b/setup.py index 515608c..19e938f 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,6 @@ packages=setuptools.find_packages(exclude=["tests", "tests.*"]), install_requires=[ "aiohttp>=3.5.4, <4", - "async_timeout>=4.0.2", "voluptuous>=0.11.5", "importlib_metadata>=3.6; python_version<'3.10'", "typing_extensions>=4.1.0; python_version<'3.11'", diff --git a/solax/__init__.py b/solax/__init__.py index 1ac860c..b128fb7 100644 --- a/solax/__init__.py +++ b/solax/__init__.py @@ -3,15 +3,21 @@ import asyncio import logging -from async_timeout import timeout - from solax.discovery import discover from solax.inverter import Inverter, InverterResponse +from solax.inverter_http_client import REQUEST_TIMEOUT _LOGGER = logging.getLogger(__name__) - -REQUEST_TIMEOUT = 5 +__all__ = ( + "discover", + "real_time_api", + "rt_request", + "Inverter", + "InverterResponse", + "RealTimeAPI", + "REQUEST_TIMEOUT", +) async def rt_request(inv: Inverter, retry, t_wait=0) -> InverterResponse: @@ -23,8 +29,7 @@ async def rt_request(inv: Inverter, retry, t_wait=0) -> InverterResponse: new_wait = (t_wait * 2) + 5 retry = retry - 1 try: - async with timeout(REQUEST_TIMEOUT): - return await inv.get_data() + return await inv.get_data() except asyncio.TimeoutError: if retry > 0: return await rt_request(inv, retry, new_wait) diff --git a/solax/discovery.py b/solax/discovery.py index 77d144f..5928453 100644 --- a/solax/discovery.py +++ b/solax/discovery.py @@ -3,9 +3,7 @@ import sys from asyncio import Future, Task from collections import defaultdict -from typing import Dict, Literal, Optional, Sequence, Set, TypedDict, Union, cast - -from async_timeout import timeout +from typing import Dict, Literal, Sequence, Set, TypedDict, Union, cast from solax.inverter import Inverter from solax.inverter_http_client import InverterHttpClient @@ -29,7 +27,6 @@ class DiscoveryKeywords(TypedDict, total=False): - timeout: Optional[float] inverters: Sequence[Inverter] return_when: Union[Literal["ALL_COMPLETED"], Literal["FIRST_COMPLETED"]] @@ -72,89 +69,86 @@ async def _discovery_task(i) -> Inverter: async def discover( host, port, pwd="", **kwargs: Unpack[DiscoveryKeywords] ) -> Union[Inverter, Set[Inverter]]: - async with timeout(kwargs.get("timeout", 15)): - done: Set[_InverterTask] = set() - pending: Set[_InverterTask] = set() - failures = set() - requests: Dict[InverterHttpClient, Future] = defaultdict( - asyncio.get_running_loop().create_future - ) - - return_when = kwargs.get("return_when", asyncio.FIRST_COMPLETED) - for cls in kwargs.get("inverters", REGISTRY): - for inverter in cls.build_all_variants(host, port, pwd): - inverter.http_client = cast( - InverterHttpClient, - _DiscoveryHttpClient( - inverter, inverter.http_client, requests[inverter.http_client] - ), - ) - - pending.add( - asyncio.create_task(_discovery_task(inverter), name=f"{inverter}") - ) - - if not pending: - raise DiscoveryError("No inverters to try to discover") - - def cancel(pending: Set[_InverterTask]) -> Set[_InverterTask]: - for task in pending: - task.cancel() - return pending - - def remove_failures_from(done: Set[_InverterTask]) -> None: - for task in set(done): - exc = task.exception() - if exc: - failures.add(exc) - done.remove(task) - - # stagger HTTP request to prevent accidental Denial Of Service - async def stagger() -> None: - for http_client, future in requests.items(): - future.set_result(asyncio.create_task(http_client.request())) - await asyncio.sleep(1) - - staggered = asyncio.create_task(stagger()) - - while pending and (not done or return_when != asyncio.FIRST_COMPLETED): - try: - done, pending = await asyncio.wait(pending, return_when=return_when) - except asyncio.CancelledError: - staggered.cancel() - await asyncio.gather( - staggered, *cancel(pending), return_exceptions=True - ) - raise - - remove_failures_from(done) - - if done and return_when == asyncio.FIRST_COMPLETED: - break - - logging.debug("%d discovery tasks are still running...", len(pending)) - - if pending and return_when != asyncio.FIRST_COMPLETED: - pending.update(done) - done.clear() + done: Set[_InverterTask] = set() + pending: Set[_InverterTask] = set() + failures = set() + requests: Dict[InverterHttpClient, Future] = defaultdict( + asyncio.get_running_loop().create_future + ) + + return_when = kwargs.get("return_when", asyncio.FIRST_COMPLETED) + for cls in kwargs.get("inverters", REGISTRY): + for inverter in cls.build_all_variants(host, port, pwd): + inverter.http_client = cast( + InverterHttpClient, + _DiscoveryHttpClient( + inverter, inverter.http_client, requests[inverter.http_client] + ), + ) + + pending.add( + asyncio.create_task(_discovery_task(inverter), name=f"{inverter}") + ) + + if not pending: + raise DiscoveryError("No inverters to try to discover") + + def cancel(pending: Set[_InverterTask]) -> Set[_InverterTask]: + for task in pending: + task.cancel() + return pending + + def remove_failures_from(done: Set[_InverterTask]) -> None: + for task in set(done): + exc = task.exception() + if exc: + failures.add(exc) + done.remove(task) + + # stagger HTTP request to prevent accidental Denial Of Service + async def stagger() -> None: + for http_client, future in requests.items(): + future.set_result(asyncio.create_task(http_client.request())) + await asyncio.sleep(1) + + staggered = asyncio.create_task(stagger()) + + while pending and (not done or return_when != asyncio.FIRST_COMPLETED): + try: + done, pending = await asyncio.wait(pending, return_when=return_when) + except asyncio.CancelledError: + staggered.cancel() + await asyncio.gather(staggered, *cancel(pending), return_exceptions=True) + raise remove_failures_from(done) - staggered.cancel() - await asyncio.gather(staggered, *cancel(pending), return_exceptions=True) - - if done: - logging.info("Discovered inverters: %s", {task.result() for task in done}) - if return_when == asyncio.FIRST_COMPLETED: - return await next(iter(done)) - - return {task.result() for task in done} - - raise DiscoveryError( - "Unable to connect to the inverter at " - f"host={host} port={port}, or your inverter is not supported yet.\n" - "Please see https://github.com/squishykid/solax/wiki/DiscoveryError\n" - f"Failures={str(failures)}" - ) + + if done and return_when == asyncio.FIRST_COMPLETED: + break + + logging.debug("%d discovery tasks are still running...", len(pending)) + + if pending and return_when != asyncio.FIRST_COMPLETED: + pending.update(done) + done.clear() + + remove_failures_from(done) + staggered.cancel() + await asyncio.gather(staggered, *cancel(pending), return_exceptions=True) + + if done: + logging.info("Discovered inverters: %s", {task.result() for task in done}) + if return_when == asyncio.FIRST_COMPLETED: + return await next(iter(done)) + + return {task.result() for task in done} + + raise DiscoveryError( + "Unable to connect to the inverter at " + f"host={host} port={port}, or your inverter is not supported yet.\n" + "Please see https://github.com/squishykid/solax/wiki/DiscoveryError\n" + f"Failures={str(failures)}" + ) class DiscoveryError(Exception): diff --git a/solax/inverter_http_client.py b/solax/inverter_http_client.py index e68adce..bc6ed81 100644 --- a/solax/inverter_http_client.py +++ b/solax/inverter_http_client.py @@ -14,6 +14,8 @@ if sys.version_info >= (3, 10): from dataclasses import KW_ONLY + +REQUEST_TIMEOUT = 5.0 _CACHE: WeakValueDictionary[int, InverterHttpClient] = WeakValueDictionary() @@ -107,7 +109,9 @@ async def request(self): async def get(self): url = self.url + "?" + self.query if self.query else self.url async with aiohttp.ClientSession() as session: - async with session.get(url, headers=self.headers) as req: + async with session.get( + url, headers=self.headers, timeout=REQUEST_TIMEOUT + ) as req: req.raise_for_status() resp = await req.read() return resp @@ -116,7 +120,9 @@ async def post(self): url = self.url + "?" + self.query if self.query else self.url data = self.data.encode("utf-8") if self.data else None async with aiohttp.ClientSession() as session: - async with session.post(url, headers=self.headers, data=data) as req: + async with session.post( + url, headers=self.headers, data=data, timeout=REQUEST_TIMEOUT + ) as req: req.raise_for_status() resp = await req.read() return resp