Skip to content

Commit

Permalink
Cache commonly used properties
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Aug 23, 2024
1 parent f6f7723 commit 2feb8b0
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 58 deletions.
21 changes: 12 additions & 9 deletions src/rod/kinematics/kinematic_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@ def __post_init__(self):
self.joints.sort(key=lambda j: j.index)
self.frames.sort(key=lambda f: f.index)

@functools.cached_property
def link_names(self) -> list[str]:
return [node.name() for node in self]

@functools.cached_property
def frame_names(self) -> list[str]:
return [frame.name() for frame in self.frames]

@functools.cached_property
def joint_names(self) -> list[str]:
return [joint.name() for joint in self.joints]

Expand Down Expand Up @@ -75,7 +78,7 @@ def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree:
# Create a dict mapping link names to tree nodes, for easy retrieval.
nodes_links_dict: dict[str, DirectedTreeNode] = {
# Add one node for each link of the model
**{link.name: DirectedTreeNode(_source=link) for link in model.links()},
**{link.name: DirectedTreeNode(_source=link) for link in model.links},
# Add special world node, that will become a frame later
TreeFrame.WORLD: DirectedTreeNode(
_source=rod.Link(
Expand Down Expand Up @@ -103,7 +106,7 @@ def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree:
# Create a dict mapping frame names to frame nodes, for easy retrieval.
nodes_frames_dict: dict[str, TreeFrame] = {
# Add a frame node for each frame in the model
**{frame.name: TreeFrame(_source=frame) for frame in model.frames()},
**{frame.name: TreeFrame(_source=frame) for frame in model.frames},
# Add implicit frames used in the SDF specification (__model__).
# The following frames are attached to the first link found in the model
# description and never moved, so that all elements expressing their pose
Expand All @@ -123,7 +126,7 @@ def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree:
) == (len(nodes_links_dict) + len(nodes_frames_dict))

# Use joints to connect nodes by defining their parent and children
for joint in model.joints():
for joint in model.joints:
if joint.child == TreeFrame.WORLD:
msg = f"A joint cannot have '{TreeFrame.WORLD}' as child"
raise RuntimeError(msg)
Expand Down Expand Up @@ -161,30 +164,30 @@ def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree:
# Get all the joints part of the kinematic tree ...
joints_in_tree_names = [
j.name
for j in model.joints()
for j in model.joints
if {j.parent, j.child}.issubset(all_node_names_in_tree)
]
joints_in_tree = [j for j in model.joints() if j.name in joints_in_tree_names]
joints_in_tree = [j for j in model.joints if j.name in joints_in_tree_names]

# ... and those that are not
joints_not_in_tree = [
j for j in model.joints() if j.name not in joints_in_tree_names
j for j in model.joints if j.name not in joints_in_tree_names
]

# A valid rod.Model does not have any dangling link and any unconnected joints.
# Here we check that the rod.Model contains a valid tree representation.
found_num_extra_joints = len(joints_not_in_tree)
expected_num_extra_joints = 1 if model.is_fixed_base() else 0
expected_num_extra_joints = 1 if model.is_fixed_base else 0

if found_num_extra_joints != expected_num_extra_joints:
if model.is_fixed_base() and found_num_extra_joints == 0:
if model.is_fixed_base and found_num_extra_joints == 0:
raise RuntimeError("Failed to find joint connecting the model to world")

unexpected_joint_names = [j.name for j in joints_not_in_tree]
raise RuntimeError(f"Found unexpected joints: {unexpected_joint_names}")

# Handle connection to world of fixed-base models
if model.is_fixed_base():
if model.is_fixed_base:
assert len(joints_not_in_tree) == 1
world_to_base_joint = joints_not_in_tree[0]

Expand Down
6 changes: 3 additions & 3 deletions src/rod/kinematics/tree_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _compute_transform(self, name: str) -> npt.NDArray:
assert relative_to in {None, ""}, (relative_to, name)
return self.kinematic_tree.model.pose.transform()

case name if name in self.kinematic_tree.joint_names():
case name if name in self.kinematic_tree.joint_names:

edge = self.kinematic_tree.joints_dict[name]
assert edge.name() == name
Expand All @@ -66,7 +66,7 @@ def _compute_transform(self, name: str) -> npt.NDArray:

return W_H_E

case name if name in self.kinematic_tree.link_names():
case name if name in self.kinematic_tree.link_names:

element = self.kinematic_tree.links_dict[name]

Expand All @@ -81,7 +81,7 @@ def _compute_transform(self, name: str) -> npt.NDArray:
W_H_L = W_H_x @ x_H_L
return W_H_L

case name if name in self.kinematic_tree.frame_names():
case name if name in self.kinematic_tree.frame_names:

element = self.kinematic_tree.frames_dict[name]

Expand Down
1 change: 1 addition & 0 deletions src/rod/sdf/link.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses

import mashumaro
import numpy as np
import numpy.typing as npt
Expand Down
17 changes: 12 additions & 5 deletions src/rod/sdf/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
import functools

import mashumaro

import rod
Expand Down Expand Up @@ -62,24 +64,26 @@ class Model(Element):

joint: Joint | list[Joint] | None = dataclasses.field(default=None)

@functools.cached_property
def is_fixed_base(self) -> bool:
joints_having_world_parent = [j for j in self.joints() if j.parent == "world"]
joints_having_world_parent = [j for j in self.joints if j.parent == "world"]
assert len(joints_having_world_parent) in {0, 1}

return len(joints_having_world_parent) > 0

def get_canonical_link(self) -> str:
if len(self.models()) != 0:
if len(self.models) != 0:
msg = "Model composition is not yet supported."
msg += " The returned canonical link could be wrong."
logging.warning(msg=msg)

if self.canonical_link is not None:
assert self.canonical_link in {l.name for l in self.links()}
assert self.canonical_link in {l.name for l in self.links}
return self.canonical_link

return self.links()[0].name
return self.links[0].name

@functools.cached_property
def models(self) -> list[Model]:
if self.model is None:
return []
Expand All @@ -90,6 +94,7 @@ def models(self) -> list[Model]:
assert isinstance(self.model, list)
return self.model

@functools.cached_property
def frames(self) -> list[Frame]:
if self.frame is None:
return []
Expand All @@ -100,6 +105,7 @@ def frames(self) -> list[Frame]:
assert isinstance(self.frame, list)
return self.frame

@functools.cached_property
def links(self) -> list[Link]:
if self.link is None:
return []
Expand All @@ -110,6 +116,7 @@ def links(self) -> list[Link]:
assert isinstance(self.link, list), type(self.link)
return self.link

@functools.cached_property
def joints(self) -> list[Joint]:
if self.joint is None:
return []
Expand Down Expand Up @@ -137,7 +144,7 @@ def add_frame(self, frame: Frame) -> None:
def resolve_uris(self) -> None:
from rod.utils import resolve_uris

for link in self.links():
for link in self.links:
for visual in link.visuals():
resolve_uris.resolve_geometry_uris(geometry=visual.geometry)

Expand Down
22 changes: 11 additions & 11 deletions src/rod/urdf/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
model.resolve_frames(is_top_level=True, explicit_frames=False)

# Model composition is not supported, ignoring sub-models
if len(model.models()) > 0:
if len(model.models) > 0:
msg = f"Ignoring unsupported sub-models of model '{model.name}'"
logging.warning(msg=msg)

Expand All @@ -94,7 +94,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
# If the model pose is not zero, warn that it will be ignored.
# In fact, the pose wrt world of the canonical link (base) will be used instead.
if (
model.is_fixed_base()
model.is_fixed_base
and model.pose is not None
and not np.allclose(model.pose.pose, np.zeros(6))
):
Expand All @@ -103,7 +103,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:

# Get the canonical link of the model
logging.debug(f"Detected '{model.get_canonical_link()}' as root link")
canonical_link: rod.Link = {l.name: l for l in model.links()}[
canonical_link: rod.Link = {l.name: l for l in model.links}[
model.get_canonical_link()
]

Expand All @@ -113,7 +113,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
# of a model, instead in URDF this reference is represented by the root link
# (that is, by definition, the SDF canonical link).
if (
not model.is_fixed_base()
not model.is_fixed_base
and canonical_link.pose is not None
and not np.allclose(canonical_link.pose.pose, np.zeros(6))
):
Expand Down Expand Up @@ -141,7 +141,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:

# Since URDF does not support plain frames as SDF, we convert all frames
# to (fixed_joint->dummy_link) sequences
for frame in model.frames():
for frame in model.frames:

# New dummy link with same name of the frame
dummy_link = {
Expand Down Expand Up @@ -203,7 +203,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
# If it is a boolean, automatically populate the list with all fixed joints.
if gazebo_preserve_fixed_joints is True:
gazebo_preserve_fixed_joints = [
j.name for j in model.joints() if j.type == "fixed"
j.name for j in model.joints if j.type == "fixed"
]

if gazebo_preserve_fixed_joints is False:
Expand All @@ -214,7 +214,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
# Check that all fixed joints to preserve are actually present in the model.
for fixed_joint_name in gazebo_preserve_fixed_joints:
logging.debug(f"Preserving fixed joint '{fixed_joint_name}'")
all_model_joint_names = {j.name for j in model.joints()}
all_model_joint_names = {j.name for j in model.joints}
if fixed_joint_name not in all_model_joint_names:
raise RuntimeError(f"Joint '{fixed_joint_name}' not found in the model")

Expand All @@ -223,7 +223,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
# ===================

# In URDF, links are directly attached to the frame of their parent joint
for link in model.links():
for link in model.links:
if link.pose is not None and not np.allclose(link.pose.pose, np.zeros(6)):
msg = "Ignoring non-trivial pose of link '{name}'"
logging.warning(msg.format(name=link.name))
Expand All @@ -237,7 +237,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
"robot": {
**{"@name": model.name},
# http://wiki.ros.org/urdf/XML/link
"link": ([world_link.to_dict()] if model.is_fixed_base() else [])
"link": ([world_link.to_dict()] if model.is_fixed_base else [])
+ [
{
"@name": l.name,
Expand Down Expand Up @@ -292,7 +292,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
for c in l.collisions()
],
}
for l in model.links()
for l in model.links
]
# Add the extra links resulting from the frame->dummy_link conversion
+ extra_links_from_frames,
Expand Down Expand Up @@ -372,7 +372,7 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str:
# mimic: does not have any SDF corresponding element
# safety_controller: does not have any SDF corresponding element
}
for j in model.joints()
for j in model.joints
if j.type in UrdfExporter.SupportedSdfJointTypes
]
# Add the extra joints resulting from the frame->link conversion
Expand Down
Loading

0 comments on commit 2feb8b0

Please sign in to comment.