Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints #71

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions pyotgw/commandprocessor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""OpenTherm Gateway command handler."""

from __future__ import annotations

import asyncio
import logging
import re
from asyncio.queues import QueueFull
from typing import TYPE_CHECKING

from pyotgw import vars as v

if TYPE_CHECKING:
from pyotgw.status import StatusManager

_LOGGER = logging.getLogger(__name__)


Expand All @@ -15,16 +21,18 @@ class CommandProcessor:

def __init__(
self,
protocol,
status_manager,
):
protocol: asyncio.Protocol,
status_manager: StatusManager,
) -> None:
"""Initialise the CommandProcessor object."""
self.protocol = protocol
self._lock = asyncio.Lock()
self._cmdq = asyncio.Queue()
self.status_manager = status_manager

async def issue_cmd(self, cmd, value, retry=3):
async def issue_cmd(
self, cmd: str, value: str | float | int, retry: int = 3
) -> bool | str | list[str] | None:
"""
Issue a command, then await and return the return value.

Expand Down Expand Up @@ -93,15 +101,15 @@ async def process(msg):
if ret is not None:
return ret

def clear_queue(self):
def clear_queue(self) -> None:
"""Clear leftover messages from the command queue"""
while not self._cmdq.empty():
_LOGGER.debug(
"Clearing leftover message from command queue: %s",
self._cmdq.get_nowait(),
)

def submit_response(self, response):
def submit_response(self, response: str) -> None:
"""Add a possible response to the command queue"""
try:
self._cmdq.put_nowait(response)
Expand All @@ -110,7 +118,7 @@ def submit_response(self, response):
_LOGGER.error("Queue full, discarded message: %s", response)

@staticmethod
def _get_expected_response(cmd, value):
def _get_expected_response(cmd: str, value: str | int) -> str:
"""Return the expected response pattern"""
if cmd == v.OTGW_CMD_REPORT:
return rf"^{cmd}:\s*([A-Z]{{2}}|{value}=[^$]+)$"
Expand Down
86 changes: 59 additions & 27 deletions pyotgw/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@
to the gateway goes here.
"""

from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass
from functools import partial
from typing import Callable, Literal, Optional, TYPE_CHECKING

import serial
import serial_asyncio_fast

from pyotgw.protocol import OpenThermProtocol

if TYPE_CHECKING:
from pyotgw import OpenThermGateway

CONNECTION_TIMEOUT = 5

MAX_RETRY_TIMEOUT = 60
Expand All @@ -23,10 +30,34 @@
_LOGGER = logging.getLogger(__name__)


@dataclass
class ConnectionConfig:
"""Config for the serial connection."""

baudrate: Optional[int]
bytesize: Optional[
Literal[serial.FIVEBITS, serial.SIXBITS, serial.SEVENBITS, serial.EIGHTBITS]
]
parity: Optional[
Literal[
serial.PARITY_NONE,
serial.PARITY_EVEN,
serial.PARITY_ODD,
serial.PARITY_MARK,
serial.PARITY_SPACE,
]
]
stopbits: Optional[
Literal[
serial.STOPBITS_ONE, serial.STOPBITS_ONE_POINT_FIVE, serial.STOPBITS_TWO
]
]


class ConnectionManager: # pylint: disable=too-many-instance-attributes
"""Functionality for setting up and tearing down a connection"""

def __init__(self, otgw):
def __init__(self, otgw: OpenThermGateway) -> None:
"""Initialise the connection manager"""
self._error = None
self._port = None
Expand All @@ -44,7 +75,7 @@ def __init__(self, otgw):
self._transport = None
self.protocol = None

async def connect(self, port, timeout=None):
async def connect(self, port: int, timeout: asyncio.Timeout = None) -> bool:
"""Start connection attempts. Return True on success or False on failure."""
if self.connected or self._connecting_task:
# We are actually reconnecting, cleanup first.
Expand All @@ -68,13 +99,13 @@ async def connect(self, port, timeout=None):
self.watchdog.start(self.reconnect, timeout=timeout or WATCHDOG_TIMEOUT)
return True

async def disconnect(self):
async def disconnect(self) -> None:
"""Disconnect from the OpenTherm Gateway."""
await self._cleanup()
if self.connected:
self.protocol.disconnect()

async def reconnect(self):
async def reconnect(self) -> None:
"""Reconnect to the OpenTherm Gateway."""
if not self._port:
_LOGGER.error("Reconnect called before connect!")
Expand All @@ -84,11 +115,11 @@ async def reconnect(self):
await self._otgw.connect(self._port)

@property
def connected(self):
def connected(self) -> bool:
"""Return the connection status"""
return self.protocol and self.protocol.connected

def set_connection_config(self, **kwargs):
def set_connection_config(self, **kwargs: ConnectionConfig) -> bool:
"""
Set the serial connection parameters before calling connect()
Valid kwargs are 'baudrate', 'bytesize', 'parity' and 'stopbits'.
Expand All @@ -104,26 +135,27 @@ def set_connection_config(self, **kwargs):
self._config.update(kwargs)
return True

async def _attempt_connect(self):
async def _attempt_connect(self) -> tuple[asyncio.Transport, asyncio.Protocol]:
"""Try to connect to the OpenTherm Gateway."""
loop = asyncio.get_running_loop()
transport = None
protocol = None
self._retry_timeout = MIN_RETRY_TIMEOUT
while transport is None:
try:
transport, protocol = await (
serial_asyncio_fast.create_serial_connection(
loop,
partial(
OpenThermProtocol,
self._otgw.status,
self.watchdog.inform,
),
self._port,
write_timeout=0,
**self._config,
)
(
transport,
protocol,
) = await serial_asyncio_fast.create_serial_connection(
loop,
partial(
OpenThermProtocol,
self._otgw.status,
self.watchdog.inform,
),
self._port,
write_timeout=0,
**self._config,
)
await asyncio.wait_for(
protocol.init_and_wait_for_activity(),
Expand Down Expand Up @@ -164,7 +196,7 @@ async def _attempt_connect(self):
transport = None
await asyncio.sleep(self._get_retry_timeout())

async def _cleanup(self):
async def _cleanup(self) -> None:
"""Cleanup possible leftovers from old connections"""
await self.watchdog.stop()
if self.protocol:
Expand All @@ -176,7 +208,7 @@ async def _cleanup(self):
except asyncio.CancelledError:
self._connecting_task = None

def _get_retry_timeout(self):
def _get_retry_timeout(self) -> asyncio.Timeout:
"""Increase if needed and return the retry timeout."""
if self._retry_timeout == MAX_RETRY_TIMEOUT:
return self._retry_timeout
Expand All @@ -188,7 +220,7 @@ def _get_retry_timeout(self):
class ConnectionWatchdog:
"""Connection watchdog"""

def __init__(self):
def __init__(self) -> None:
"""Initialise the object"""
self._callback = None
self.timeout = WATCHDOG_TIMEOUT
Expand All @@ -197,11 +229,11 @@ def __init__(self):
self.loop = asyncio.get_event_loop()

@property
def is_active(self):
def is_active(self) -> bool:
"""Return watchdog status"""
return self._wd_task is not None

async def inform(self):
async def inform(self) -> None:
"""Reset the watchdog timer."""
async with self._lock:
if not self.is_active:
Expand All @@ -215,7 +247,7 @@ async def inform(self):
self._wd_task = self.loop.create_task(self._watchdog(self.timeout))
_LOGGER.debug("Watchdog timer reset!")

def start(self, callback, timeout):
def start(self, callback: Callable[[], None], timeout: asyncio.Timeout) -> bool:
"""Start the watchdog, return boolean indicating success"""
if self.is_active:
return False
Expand All @@ -224,7 +256,7 @@ def start(self, callback, timeout):
self._wd_task = self.loop.create_task(self._watchdog(timeout))
return self.is_active

async def stop(self):
async def stop(self) -> None:
"""Stop the watchdog"""
async with self._lock:
if not self.is_active:
Expand All @@ -236,7 +268,7 @@ async def stop(self):
except asyncio.CancelledError:
self._wd_task = None

async def _watchdog(self, timeout):
async def _watchdog(self, timeout: asyncio.Timeout) -> None:
"""Trigger and cancel the watchdog after timeout. Schedule callback."""
await asyncio.sleep(timeout)
_LOGGER.debug("Watchdog triggered!")
Expand Down
Loading