Skip to content

Commit

Permalink
Fix ROS 2 message equality check (#964)
Browse files Browse the repository at this point in the history
### Public-Facing Changes

Fix ROS 2 message equality check

### Description
Use a custom comparison function that compares messages by their
`__slots__` fields.

Fixes #959
Resolves FG-4775
  • Loading branch information
achim-k authored Sep 6, 2023
1 parent c75a0aa commit 1f3a471
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/mcap-ros2-support/mcap_ros2/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.5.1"
__version__ = "0.5.2"
24 changes: 24 additions & 0 deletions python/mcap-ros2-support/mcap_ros2/_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ def _read_complex_type(
"__slots__": [field.name for field in msgdef.fields],
"__repr__": __repr__,
"__str__": __repr__,
"__eq__": __eq__,
"__ne__": __ne__,
"_type": str(msgdef.base_type),
"_full_text": str(msgdef),
},
Expand Down Expand Up @@ -605,3 +607,25 @@ def _coerce_values(
def __repr__(self: Any) -> str:
fields = ", ".join(f"{field}={getattr(self, field)}" for field in self.__slots__)
return f"{self.__name__}({fields})"


def __eq__(self: Any, other: Any) -> bool:
if not isinstance(other, type(self)):
return False

if (
not hasattr(self, "__slots__")
or not hasattr(other, "__slots__")
or len(self.__slots__) != len(other.__slots__)
):
return False

for attr in self.__slots__:
if getattr(self, attr) != getattr(other, attr):
return False

return True


def __ne__(self: Any, other: Any) -> bool:
return not __eq__(self, other)
13 changes: 13 additions & 0 deletions python/mcap-ros2-support/tests/test_ros2_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,16 @@ def test_ros2_decoder():
assert ros_msg._full_text == "# std_msgs/Empty"
count += 1
assert count == 10


def test_ros2_decoder_msg_eq():
with generate_sample_data() as m:
reader = make_reader(m, decoder_factories=[DecoderFactory()])

decoded_messages = reader.iter_decoded_messages("/chatter")
_, _, _, msg0 = next(decoded_messages)
_, _, _, msg1 = next(decoded_messages)
assert msg0.data == "string message 0"
assert msg1.data == "string message 1"
assert msg0 == msg0 and msg1 == msg1
assert msg0 != msg1 and msg1 != msg0

0 comments on commit 1f3a471

Please sign in to comment.