diff --git a/goodwe/inverter.py b/goodwe/inverter.py index c0b962f..b9df4b4 100644 --- a/goodwe/inverter.py +++ b/goodwe/inverter.py @@ -91,6 +91,7 @@ class Inverter(ABC): def __init__(self, host: str, port: int, comm_addr: int = 0, timeout: int = 1, retries: int = 3): self._protocol: InverterProtocol = self._create_protocol(host, port, comm_addr, timeout, retries) self._consecutive_failures_count: int = 0 + self.keep_alive: bool = True self.model_name: str | None = None self.serial_number: str | None = None @@ -129,6 +130,9 @@ async def _read_from_socket(self, command: ProtocolCommand) -> ProtocolResponse: except RequestFailedException as ex: self._consecutive_failures_count += 1 raise RequestFailedException(ex.message, self._consecutive_failures_count) from None + finally: + if not self.keep_alive: + self._protocol.close_transport() @abstractmethod async def read_device_info(self): diff --git a/goodwe/protocol.py b/goodwe/protocol.py index be15329..1c1da9a 100644 --- a/goodwe/protocol.py +++ b/goodwe/protocol.py @@ -3,6 +3,8 @@ import asyncio import io import logging +import platform +import socket from asyncio.futures import Future from typing import Tuple, Optional, Callable @@ -54,13 +56,15 @@ def _ensure_lock(self) -> asyncio.Lock: logger.debug("Creating lock instance for current event loop.") self._lock = asyncio.Lock() self._running_loop = asyncio.get_event_loop() - self._close_transport() + self.close_transport() return self._lock - def _close_transport(self) -> None: + def close_transport(self) -> None: + """Close the underlying transport/connection.""" raise NotImplementedError() async def send_request(self, command: ProtocolCommand) -> Future: + """Convert command to request and send it to inverter.""" raise NotImplementedError() def read_command(self, offset: int, count: int) -> ProtocolCommand: @@ -111,7 +115,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: logger.debug("Socket closed with error: %s.", exc) else: logger.debug("Socket closed.") - self._close_transport() + self.close_transport() def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: """On datagram received""" @@ -130,13 +134,13 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: except RequestRejectedException as ex: logger.debug("Received exception response: %s", data.hex()) self.response_future.set_exception(ex) - self._close_transport() + self.close_transport() def error_received(self, exc: Exception) -> None: """On error received""" logger.debug("Received error: %s", exc) self.response_future.set_exception(exc) - self._close_transport() + self.close_transport() async def send_request(self, command: ProtocolCommand) -> Future: """Send message via transport""" @@ -172,9 +176,9 @@ def _retry_mechanism(self) -> None: else: logger.debug("Max number of retries (%d) reached, request %s failed.", self.retries, self.command) self.response_future.set_exception(MaxRetriesException) - self._close_transport() + self.close_transport() - def _close_transport(self) -> None: + def close_transport(self) -> None: if self._transport: try: self._transport.close() @@ -211,6 +215,14 @@ async def _connect(self) -> None: lambda: self, host=self._host, port=self._port, ) + sock = self._transport.get_extra_info('socket') + if sock is not None: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 10) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) + if platform.system() == 'Windows': + sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 10000, 10000)) def connection_made(self, transport: asyncio.DatagramTransport) -> None: """On connection made""" @@ -219,7 +231,7 @@ def connection_made(self, transport: asyncio.DatagramTransport) -> None: def eof_received(self) -> None: logger.debug("EOF received.") - self._close_transport() + self.close_transport() def connection_lost(self, exc: Optional[Exception]) -> None: """On connection lost""" @@ -227,7 +239,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: logger.debug("Connection closed with error: %s.", exc) else: logger.debug("Connection closed.") - self._close_transport() + self.close_transport() def data_received(self, data: bytes) -> None: """On data received""" @@ -241,19 +253,19 @@ def data_received(self, data: bytes) -> None: else: logger.debug("Received invalid response: %s", data.hex()) self.response_future.set_exception(RequestRejectedException()) - self._close_transport() + self.close_transport() except asyncio.InvalidStateError: logger.debug("Response already handled: %s", data.hex()) except RequestRejectedException as ex: logger.debug("Received exception response: %s", data.hex()) self.response_future.set_exception(ex) - # self._close_transport() + # self.close_transport() def error_received(self, exc: Exception) -> None: """On error received""" logger.debug("Received error: %s", exc) self.response_future.set_exception(exc) - self._close_transport() + self.close_transport() async def send_request(self, command: ProtocolCommand) -> Future: """Send message via transport""" @@ -271,7 +283,7 @@ async def send_request(self, command: ProtocolCommand) -> Future: self._retry += 1 if self._lock and self._lock.locked(): self._lock.release() - self._close_transport() + self.close_transport() return await self.send_request(command) else: return self._max_retries_reached() @@ -308,16 +320,16 @@ def _timeout_mechanism(self) -> None: if self._timer: logger.debug("Failed to receive response to %s in time (%ds).", self.command, self.timeout) self._timer = None - self._close_transport() + self.close_transport() def _max_retries_reached(self) -> Future: logger.debug("Max number of retries (%d) reached, request %s failed.", self.retries, self.command) - self._close_transport() + self.close_transport() self.response_future = asyncio.get_running_loop().create_future() self.response_future.set_exception(MaxRetriesException) return self.response_future - def _close_transport(self) -> None: + def close_transport(self) -> None: if self._transport: try: self._transport.close() diff --git a/tests/stability_check.py b/tests/stability_check.py index 23d41a4..0a764f9 100644 --- a/tests/stability_check.py +++ b/tests/stability_check.py @@ -30,9 +30,10 @@ async def get_runtime_data(): - i = 1 - inverter = await goodwe.connect('127.0.0.1', 502) + inverter = await goodwe.connect(host='127.0.0.1', port=502, timeout=1, retries=3) + # inverter.keep_alive = False + i = 1 while True: logger.info("################################") logger.info(" Request %d", i)