Skip to content

Commit

Permalink
Add mypy support and fixup project to give no errors
Browse files Browse the repository at this point in the history
* Permissive mypy configuration as starting point
* Add minimal type annotations to get no mypy errors
* Add runtime test for self.network before using the network
* Network.add_node() doesn't accept LocalNode
* PeriodicMessageTask.update() don't stop the task unless its running
* Variable.desc ensure that the object is int
* Variable.read() fail with ValueError unless a valid fmt is used
* Variable.write() ensure the description is a string
* BaseNode.__init__() fail if no node_id is provided
* ObjectDictionary.__getitem__() when splitting "." only return if the object is not an ODVariable
* ODRecord.__eq__(), ODArray.__eq__() and ODVariable.__eq__() test type of other before comparing
* ODVariable.encode_raw(), .decode_phys(), .encode_phys() add type tests of ensure the input is of correct type
* PdoMap various methods: ensure necessary attributes are set
  • Loading branch information
sveinse committed Jul 10, 2024
1 parent 3aa509d commit 3b66ae5
Show file tree
Hide file tree
Showing 11 changed files with 204 additions and 83 deletions.
7 changes: 6 additions & 1 deletion canopen/emcy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import struct
import logging
import threading
Expand Down Expand Up @@ -52,7 +53,7 @@ def reset(self):

def wait(
self, emcy_code: Optional[int] = None, timeout: float = 10
) -> "EmcyError":
) -> Optional[EmcyError]:
"""Wait for a new EMCY to arrive.
:param emcy_code: EMCY code to wait for
Expand Down Expand Up @@ -86,10 +87,14 @@ def __init__(self, cob_id: int):
self.cob_id = cob_id

def send(self, code: int, register: int = 0, data: bytes = b""):
if self.network is None:
raise RuntimeError("A Network is required")
payload = EMCY_STRUCT.pack(code, register, data)
self.network.send_message(self.cob_id, payload)

def reset(self, register: int = 0, data: bytes = b""):
if self.network is None:
raise RuntimeError("A Network is required")
payload = EMCY_STRUCT.pack(0, register, data)
self.network.send_message(self.cob_id, payload)

Expand Down
54 changes: 33 additions & 21 deletions canopen/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
from collections.abc import MutableMapping
import logging
import threading
from typing import Callable, Dict, Iterator, List, Optional, Union
from typing import Callable, Dict, Iterator, List, Optional, Union, TYPE_CHECKING, TextIO

try:
import can
from can import Listener
from can import CanError
except ImportError:
# Do not fail if python-can is not installed
can = None
CanError = Exception
class Listener:
""" Dummy listener """
# Type checkers don't like this conditional logic, so it is only run when
# not type checking
if not TYPE_CHECKING:
# Do not fail if python-can is not installed
can = None
CanError = Exception
class Listener:
""" Dummy listener """

from canopen.node import RemoteNode, LocalNode
from canopen.sync import SyncProducer
Expand All @@ -24,6 +27,9 @@ class Listener:
from canopen.objectdictionary.eds import import_from_node
from canopen.objectdictionary import ObjectDictionary

if TYPE_CHECKING:
from can.typechecking import CanData

logger = logging.getLogger(__name__)

Callback = Callable[[int, bytearray, float], None]
Expand All @@ -45,7 +51,7 @@ def __init__(self, bus: Optional[can.BusABC] = None):
#: List of :class:`can.Listener` objects.
#: Includes at least MessageListener.
self.listeners = [MessageListener(self)]
self.notifier = None
self.notifier: Optional[can.Notifier] = None
self.nodes: Dict[int, Union[RemoteNode, LocalNode]] = {}
self.subscribers: Dict[int, List[Callback]] = {}
self.send_lock = threading.Lock()
Expand Down Expand Up @@ -138,15 +144,15 @@ def __exit__(self, type, value, traceback):

def add_node(
self,
node: Union[int, RemoteNode, LocalNode],
object_dictionary: Union[str, ObjectDictionary, None] = None,
node: Union[int, RemoteNode],
object_dictionary: Union[str, ObjectDictionary, TextIO, None] = None,
upload_eds: bool = False,
) -> RemoteNode:
"""Add a remote node to the network.
:param node:
Can be either an integer representing the node ID, a
:class:`canopen.RemoteNode` or :class:`canopen.LocalNode` object.
:class:`canopen.RemoteNode` object.
:param object_dictionary:
Can be either a string for specifying the path to an
Object Dictionary file or a
Expand All @@ -161,14 +167,16 @@ def add_node(
if upload_eds:
logger.info("Trying to read EDS from node %d", node)
object_dictionary = import_from_node(node, self)
node = RemoteNode(node, object_dictionary)
self[node.id] = node
return node
nodeobj = RemoteNode(node, object_dictionary)
else:
nodeobj = node
self[nodeobj.id] = nodeobj
return nodeobj

def create_node(
self,
node: int,
object_dictionary: Union[str, ObjectDictionary, None] = None,
object_dictionary: Union[str, ObjectDictionary, TextIO, None] = None,
) -> LocalNode:
"""Create a local node in the network.
Expand All @@ -183,11 +191,13 @@ def create_node(
The Node object that was added.
"""
if isinstance(node, int):
node = LocalNode(node, object_dictionary)
self[node.id] = node
return node
nodeobj = LocalNode(node, object_dictionary)
else:
nodeobj = node
self[nodeobj.id] = nodeobj
return nodeobj

def send_message(self, can_id: int, data: bytes, remote: bool = False) -> None:
def send_message(self, can_id: int, data: CanData, remote: bool = False) -> None:
"""Send a raw CAN message to the network.
This method may be overridden in a subclass if you need to integrate
Expand Down Expand Up @@ -215,7 +225,7 @@ def send_message(self, can_id: int, data: bytes, remote: bool = False) -> None:
self.check()

def send_periodic(
self, can_id: int, data: bytes, period: float, remote: bool = False
self, can_id: int, data: CanData, period: float, remote: bool = False
) -> PeriodicMessageTask:
"""Start sending a message periodically.
Expand Down Expand Up @@ -295,7 +305,7 @@ class PeriodicMessageTask:
def __init__(
self,
can_id: int,
data: bytes,
data: CanData,
period: float,
bus,
remote: bool = False,
Expand Down Expand Up @@ -335,10 +345,12 @@ def update(self, data: bytes) -> None:
old_data = self.msg.data
self.msg.data = new_data
if hasattr(self._task, "modify_data"):
assert self._task is not None # This will never be None, but mypy needs this
self._task.modify_data(self.msg)
elif new_data != old_data:
# Stop and start (will mess up period unfortunately)
self._task.stop()
if self._task is not None:
self._task.stop()
self._start()


Expand Down
23 changes: 17 additions & 6 deletions canopen/nmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import logging
import struct
import time
from typing import Callable, Optional
from typing import Callable, Optional, List, TYPE_CHECKING

if TYPE_CHECKING:
from canopen.network import Network, PeriodicMessageTask

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -45,7 +48,7 @@ class NmtBase:

def __init__(self, node_id: int):
self.id = node_id
self.network = None
self.network: Optional[Network] = None
self._state = 0

def on_command(self, can_id, data, timestamp):
Expand Down Expand Up @@ -107,11 +110,11 @@ class NmtMaster(NmtBase):
def __init__(self, node_id: int):
super(NmtMaster, self).__init__(node_id)
self._state_received = None
self._node_guarding_producer = None
self._node_guarding_producer: Optional[PeriodicMessageTask] = None
#: Timestamp of last heartbeat message
self.timestamp: Optional[float] = None
self.state_update = threading.Condition()
self._callbacks = []
self._callbacks: List[Callable[[int], None]] = []

def on_heartbeat(self, can_id, data, timestamp):
with self.state_update:
Expand Down Expand Up @@ -139,6 +142,8 @@ def send_command(self, code: int):
super(NmtMaster, self).send_command(code)
logger.info(
"Sending NMT command 0x%X to node %d", code, self.id)
if self.network is None:
raise RuntimeError("A Network is required")
self.network.send_message(0, [code, self.id])

def wait_for_heartbeat(self, timeout: float = 10):
Expand Down Expand Up @@ -181,7 +186,9 @@ def start_node_guarding(self, period: float):
Period (in seconds) at which the node guarding should be advertised to the slave node.
"""
if self._node_guarding_producer : self.stop_node_guarding()
self._node_guarding_producer = self.network.send_periodic(0x700 + self.id, None, period, True)
if self.network is None:
raise RuntimeError("A Network is required")
self._node_guarding_producer = self.network.send_periodic(0x700 + self.id, [], period, True)

def stop_node_guarding(self):
"""Stops the node guarding mechanism."""
Expand All @@ -197,7 +204,7 @@ class NmtSlave(NmtBase):

def __init__(self, node_id: int, local_node):
super(NmtSlave, self).__init__(node_id)
self._send_task = None
self._send_task: Optional[PeriodicMessageTask] = None
self._heartbeat_time_ms = 0
self._local_node = local_node

Expand All @@ -216,6 +223,8 @@ def send_command(self, code: int) -> None:

if self._state == 0:
logger.info("Sending boot-up message")
if self.network is None:
raise RuntimeError("A Network is required")
self.network.send_message(0x700 + self.id, [0])

# The heartbeat service should start on the transition
Expand Down Expand Up @@ -246,6 +255,8 @@ def start_heartbeat(self, heartbeat_time_ms: int):
self.stop_heartbeat()
if heartbeat_time_ms > 0:
logger.info("Start the heartbeat timer, interval is %d ms", self._heartbeat_time_ms)
if self.network is None:
raise RuntimeError("A network is required")
self._send_task = self.network.send_periodic(
0x700 + self.id, [self._state], heartbeat_time_ms / 1000.0)

Expand Down
11 changes: 7 additions & 4 deletions canopen/node/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TextIO, Union
from typing import TextIO, Union, Optional
from canopen.objectdictionary import ObjectDictionary, import_od


Expand All @@ -14,13 +14,16 @@ class BaseNode:

def __init__(
self,
node_id: int,
object_dictionary: Union[ObjectDictionary, str, TextIO],
node_id: Optional[int],
object_dictionary: Union[ObjectDictionary, str, TextIO, None],
):
self.network = None

if not isinstance(object_dictionary, ObjectDictionary):
object_dictionary = import_od(object_dictionary, node_id)
self.object_dictionary = object_dictionary

self.id = node_id or self.object_dictionary.node_id
node_id = node_id or self.object_dictionary.node_id
if node_id is None:
raise ValueError("Node ID must be specified")
self.id: int = node_id
28 changes: 22 additions & 6 deletions canopen/node/local.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,45 @@
import logging
from typing import Dict, Union
from typing import Dict, Union, List, Protocol, TextIO, Optional

from canopen.node.base import BaseNode
from canopen.sdo import SdoServer, SdoAbortedError
from canopen.pdo import PDO, TPDO, RPDO
from canopen.nmt import NmtSlave
from canopen.emcy import EmcyProducer
from canopen.objectdictionary import ObjectDictionary
from canopen.objectdictionary import ObjectDictionary, ODVariable
from canopen import objectdictionary

logger = logging.getLogger(__name__)


class WriteCallback(Protocol):
"""LocalNode Write Callback Protocol"""
def __call__(self, *, index: int, subindex: int,
od: ODVariable,
data: bytes) -> None:
''' Write Callback '''


class ReadCallback(Protocol):
"""LocalNode Read Callback Protocol"""
def __call__(self, *, index: int, subindex: int,
od: ODVariable
) -> Union[bool, int, float, str, bytes, None]:
''' Read Callback '''


class LocalNode(BaseNode):

def __init__(
self,
node_id: int,
object_dictionary: Union[ObjectDictionary, str],
node_id: Optional[int],
object_dictionary: Union[ObjectDictionary, str, TextIO, None],
):
super(LocalNode, self).__init__(node_id, object_dictionary)

self.data_store: Dict[int, Dict[int, bytes]] = {}
self._read_callbacks = []
self._write_callbacks = []
self._read_callbacks: List[ReadCallback] = []
self._write_callbacks: List[WriteCallback] = []

self.sdo = SdoServer(0x600 + self.id, 0x580 + self.id, self)
self.tpdo = TPDO(self)
Expand Down
8 changes: 4 additions & 4 deletions canopen/node/remote.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Union, TextIO
from typing import Union, TextIO, List, Optional

from canopen.sdo import SdoClient, SdoCommunicationError, SdoAbortedError
from canopen.nmt import NmtMaster
Expand All @@ -26,16 +26,16 @@ class RemoteNode(BaseNode):

def __init__(
self,
node_id: int,
object_dictionary: Union[ObjectDictionary, str, TextIO],
node_id: Optional[int],
object_dictionary: Union[ObjectDictionary, str, TextIO, None],
load_od: bool = False,
):
super(RemoteNode, self).__init__(node_id, object_dictionary)

#: Enable WORKAROUND for reversed PDO mapping entries
self.curtis_hack = False

self.sdo_channels = []
self.sdo_channels: List[SdoClient] = []
self.sdo = self.add_sdo(0x600 + self.id, 0x580 + self.id)
self.tpdo = TPDO(self)
self.rpdo = RPDO(self)
Expand Down
Loading

0 comments on commit 3b66ae5

Please sign in to comment.