Skip to content

Commit

Permalink
Make TransformManager.transforms a property
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFabisch committed Jul 17, 2023
1 parent f7a2018 commit 20fc809
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
15 changes: 12 additions & 3 deletions pytransform3d/transform_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def __init__(self, strict_check=True, check=True):

# Names of nodes
self.nodes = []
# Rigid transformations between nodes
self.transforms = {}

# A pair (self.i[n], self.j[n]) represents indices of connected nodes
self.i = []
Expand All @@ -53,6 +51,11 @@ def __init__(self, strict_check=True, check=True):

self._cached_shortest_paths = {}

@property
@abc.abstractmethod
def transforms(self):
"""Rigid transformations between nodes."""

def has_frame(self, frame):
"""Check if frame has been registered.
Expand Down Expand Up @@ -307,6 +310,12 @@ class TransformManager(TransformGraphBase):
"""
def __init__(self, strict_check=True, check=True):
super(TransformManager, self).__init__(strict_check, check)
self._transforms = {}

@property
def transforms(self):
"""Rigid transformations between nodes."""
return self._transforms

def _check_transform(self, A2B):
return check_transform(A2B, strict_check=self.strict_check)
Expand Down Expand Up @@ -583,7 +592,7 @@ def set_transform_manager_state(self, tm_dict):
Serializable dict.
"""
transforms = tm_dict.get("transforms")
self.transforms = dict([
self._transforms = dict([
(tuple(k), np.array(v).reshape(4, 4)) for k, v in transforms])
self.nodes = tm_dict.get("nodes")
self.i = tm_dict.get("i")
Expand Down
18 changes: 15 additions & 3 deletions pytransform3d/transform_manager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class TransformGraphBase(abc.ABC):
strict_check: bool
check: bool
nodes: List[Hashable]
transforms: Dict[Tuple[Hashable, Hashable], np.ndarray]
i: List[int]
j: List[int]
transform_to_ij_index = Dict[Tuple[Hashable, Hashable], int]
Expand All @@ -20,6 +19,13 @@ class TransformGraphBase(abc.ABC):
predecessors: np.ndarray
_cached_shortest_paths: Dict[Tuple[int, int], List[Hashable]]

def __init__(self, strict_check: bool = ...,
check: bool = ...) -> "TransformGraphBase": ...

@property
@abc.abstractmethod
def transforms(self) -> Dict[Tuple[Hashable, Hashable], np.ndarray]: ...

def has_frame(self, frame: Hashable) -> bool: ...

def add_transform(self, from_frame: Hashable, to_frame: Hashable,
Expand All @@ -31,7 +37,7 @@ class TransformGraphBase(abc.ABC):

def remove_transform(
self, from_frame: Hashable,
to_frame: Hashable) -> "TransformManager": ...
to_frame: Hashable) -> "TransformGraphBase": ...

def get_transform(
self, from_frame: Hashable, to_frame: Hashable) -> Any: ...
Expand All @@ -50,7 +56,13 @@ class TransformGraphBase(abc.ABC):


class TransformManager(TransformGraphBase):
def __init__(self, strict_check: bool = ..., check: bool = ...): ...
_transforms: Dict[Tuple[Hashable, Hashable], np.ndarray]

def __init__(self, strict_check: bool = ...,
check: bool = ...) -> "TransformManager": ...

@property
def transforms(self) -> Dict[Tuple[Hashable, Hashable], np.ndarray]: ...

def add_transform(self, from_frame: Hashable, to_frame: Hashable,
A2B: np.ndarray) -> "TransformManager": ...
Expand Down

0 comments on commit 20fc809

Please sign in to comment.