From 0b9da5a73a8b96e2256d0ea02a721b916a355adc Mon Sep 17 00:00:00 2001 From: mle Date: Mon, 20 May 2024 23:41:42 +0200 Subject: [PATCH] Make fragment handling more strict --- goodwe/protocol.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/goodwe/protocol.py b/goodwe/protocol.py index 37c8b70..d8d6a33 100644 --- a/goodwe/protocol.py +++ b/goodwe/protocol.py @@ -42,6 +42,7 @@ def __init__(self, host: str, port: int, comm_addr: int, timeout: int, retries: self.response_future: Future | None = None self.command: ProtocolCommand | None = None self._partial_data: bytes | None = None + self._partial_missing: int = 0 def _ensure_lock(self) -> asyncio.Lock: """Validate (or create) asyncio Lock. @@ -125,22 +126,21 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: self._timer.cancel() self._timer = None try: - if self._partial_data: - logger.debug("Received another response fragment: %s.", data.hex()) + if self._partial_data and self._partial_missing == len(data): + logger.debug("Composed fragmented response: %s", data.hex()) data = self._partial_data + data - if self.command.validator(data): - if self._partial_data: - logger.debug("Composed fragmented response: %s", data.hex()) - else: - logger.debug("Received: %s", data.hex()) self._partial_data = None + self._partial_missing = 0 + if self.command.validator(data): + logger.debug("Received: %s", data.hex()) self.response_future.set_result(data) else: logger.debug("Received invalid response: %s", data.hex()) asyncio.get_running_loop().call_soon(self._retry_mechanism) - except PartialResponseException: - logger.debug("Received response fragment: %s", data.hex()) + except PartialResponseException as ex: + logger.debug("Received response fragment (%d of %d): %s", ex.length, ex.expected, data.hex()) self._partial_data = data + self._partial_missing = ex.expected - ex.length return except asyncio.InvalidStateError: logger.debug("Response already handled: %s", data.hex()) @@ -161,6 +161,7 @@ async def send_request(self, command: ProtocolCommand) -> Future: await self._connect() response_future = asyncio.get_running_loop().create_future() self._retry = 0 + self._partial_data = None self._send_request(command, response_future) await response_future return response_future @@ -266,24 +267,23 @@ def data_received(self, data: bytes) -> None: if self._timer: self._timer.cancel() try: - if self._partial_data: - logger.debug("Received another response fragment: %s.", data.hex()) + if self._partial_data and self._partial_missing == len(data): + logger.debug("Composed fragmented response: %s", data.hex()) data = self._partial_data + data + self._partial_data = None + self._partial_missing = 0 if self.command.validator(data): - if self._partial_data: - logger.debug("Composed fragmented response: %s", data.hex()) - else: - logger.debug("Received: %s", data.hex()) + logger.debug("Received: %s", data.hex()) self._retry = 0 - self._partial_data = None self.response_future.set_result(data) else: logger.debug("Received invalid response: %s", data.hex()) self.response_future.set_exception(RequestRejectedException()) self._close_transport() - except PartialResponseException: - logger.debug("Received response fragment: %s", data.hex()) + except PartialResponseException as ex: + logger.debug("Received response fragment (%d of %d): %s", ex.length, ex.expected, data.hex()) self._partial_data = data + self._partial_missing = ex.expected - ex.length return except asyncio.InvalidStateError: logger.debug("Response already handled: %s", data.hex()) @@ -304,6 +304,7 @@ async def send_request(self, command: ProtocolCommand) -> Future: try: await asyncio.wait_for(self._connect(), timeout=5) response_future = asyncio.get_running_loop().create_future() + self._partial_data = None self._send_request(command, response_future) await response_future return response_future