From 20fc8099157855b00ada8cd077bd0a9591e62af1 Mon Sep 17 00:00:00 2001 From: Alexander Fabisch Date: Mon, 17 Jul 2023 11:08:56 +0200 Subject: [PATCH] Make TransformManager.transforms a property --- pytransform3d/transform_manager.py | 15 ++++++++++++--- pytransform3d/transform_manager.pyi | 18 +++++++++++++++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/pytransform3d/transform_manager.py b/pytransform3d/transform_manager.py index e5c434383..833652ad3 100644 --- a/pytransform3d/transform_manager.py +++ b/pytransform3d/transform_manager.py @@ -35,8 +35,6 @@ def __init__(self, strict_check=True, check=True): # Names of nodes self.nodes = [] - # Rigid transformations between nodes - self.transforms = {} # A pair (self.i[n], self.j[n]) represents indices of connected nodes self.i = [] @@ -53,6 +51,11 @@ def __init__(self, strict_check=True, check=True): self._cached_shortest_paths = {} + @property + @abc.abstractmethod + def transforms(self): + """Rigid transformations between nodes.""" + def has_frame(self, frame): """Check if frame has been registered. @@ -307,6 +310,12 @@ class TransformManager(TransformGraphBase): """ def __init__(self, strict_check=True, check=True): super(TransformManager, self).__init__(strict_check, check) + self._transforms = {} + + @property + def transforms(self): + """Rigid transformations between nodes.""" + return self._transforms def _check_transform(self, A2B): return check_transform(A2B, strict_check=self.strict_check) @@ -583,7 +592,7 @@ def set_transform_manager_state(self, tm_dict): Serializable dict. """ transforms = tm_dict.get("transforms") - self.transforms = dict([ + self._transforms = dict([ (tuple(k), np.array(v).reshape(4, 4)) for k, v in transforms]) self.nodes = tm_dict.get("nodes") self.i = tm_dict.get("i") diff --git a/pytransform3d/transform_manager.pyi b/pytransform3d/transform_manager.pyi index d33be2217..3fa9a6766 100644 --- a/pytransform3d/transform_manager.pyi +++ b/pytransform3d/transform_manager.pyi @@ -11,7 +11,6 @@ class TransformGraphBase(abc.ABC): strict_check: bool check: bool nodes: List[Hashable] - transforms: Dict[Tuple[Hashable, Hashable], np.ndarray] i: List[int] j: List[int] transform_to_ij_index = Dict[Tuple[Hashable, Hashable], int] @@ -20,6 +19,13 @@ class TransformGraphBase(abc.ABC): predecessors: np.ndarray _cached_shortest_paths: Dict[Tuple[int, int], List[Hashable]] + def __init__(self, strict_check: bool = ..., + check: bool = ...) -> "TransformGraphBase": ... + + @property + @abc.abstractmethod + def transforms(self) -> Dict[Tuple[Hashable, Hashable], np.ndarray]: ... + def has_frame(self, frame: Hashable) -> bool: ... def add_transform(self, from_frame: Hashable, to_frame: Hashable, @@ -31,7 +37,7 @@ class TransformGraphBase(abc.ABC): def remove_transform( self, from_frame: Hashable, - to_frame: Hashable) -> "TransformManager": ... + to_frame: Hashable) -> "TransformGraphBase": ... def get_transform( self, from_frame: Hashable, to_frame: Hashable) -> Any: ... @@ -50,7 +56,13 @@ class TransformGraphBase(abc.ABC): class TransformManager(TransformGraphBase): - def __init__(self, strict_check: bool = ..., check: bool = ...): ... + _transforms: Dict[Tuple[Hashable, Hashable], np.ndarray] + + def __init__(self, strict_check: bool = ..., + check: bool = ...) -> "TransformManager": ... + + @property + def transforms(self) -> Dict[Tuple[Hashable, Hashable], np.ndarray]: ... def add_transform(self, from_frame: Hashable, to_frame: Hashable, A2B: np.ndarray) -> "TransformManager": ...