diff --git a/rosbag2_cpp/src/rosbag2_cpp/message_definitions/local_message_definition_source.cpp b/rosbag2_cpp/src/rosbag2_cpp/message_definitions/local_message_definition_source.cpp new file mode 100644 index 0000000000..3382efb6a7 --- /dev/null +++ b/rosbag2_cpp/src/rosbag2_cpp/message_definitions/local_message_definition_source.cpp @@ -0,0 +1,311 @@ +// Copyright 2022, Foxglove Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rosbag2_cpp/message_definitions/local_message_definition_source.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "rosbag2_cpp/logging.hpp" + +namespace rosbag2_cpp +{ + +/// A type name did not match expectations, so a definition could not be looked for. +class TypenameNotUnderstoodError : public std::exception +{ +private: + std::string name_; + +public: + explicit TypenameNotUnderstoodError(std::string name) + : name_(std::move(name)) + {} + + const char * what() const noexcept override + { + return name_.c_str(); + } +}; + +// Match datatype names (foo_msgs/Bar or foo_msgs/msg/Bar) +static const std::regex PACKAGE_TYPENAME_REGEX{R"(^([a-zA-Z0-9_]+)/(?:msg/|srv/)?([a-zA-Z0-9_]+)$)"}; + +// Match field types from .msg and .srv definitions ("foo_msgs/Bar" in "foo_msgs/Bar[] bar") +static const std::regex MSG_FIELD_TYPE_REGEX{R"((?:^|\n)\s*([a-zA-Z0-9_/]+)(?:\[[^\]]*\])?\s+)"}; + +// match field types from `.idl` definitions ("foo_msgs/msg/bar" in #include ) +static const std::regex IDL_FIELD_TYPE_REGEX{ + R"((?:^|\n)#include\s+(?:"|<)([a-zA-Z0-9_/]+)\.idl(?:"|>))"}; + +static const std::unordered_set PRIMITIVE_TYPES{ + "bool", "byte", "char", "float32", "float64", "int8", "uint8", + "int16", "uint16", "int32", "uint32", "int64", "uint64", "string"}; + +static std::set parse_msg_dependencies( + const std::string & text, + const std::string & package_context) +{ + std::set dependencies; + + for (std::sregex_iterator iter(text.begin(), text.end(), MSG_FIELD_TYPE_REGEX); + iter != std::sregex_iterator(); ++iter) + { + std::string type = (*iter)[1]; + if (PRIMITIVE_TYPES.find(type) != PRIMITIVE_TYPES.end()) { + continue; + } + if (type.find('/') == std::string::npos) { + dependencies.insert(package_context + '/' + std::move(type)); + } else { + dependencies.insert(std::move(type)); + } + } + return dependencies; +} + +static std::set parse_idl_dependencies(const std::string & text) +{ + std::set dependencies; + + for (std::sregex_iterator iter(text.begin(), text.end(), IDL_FIELD_TYPE_REGEX); + iter != std::sregex_iterator(); ++iter) + { + dependencies.insert((*iter)[1]); + } + return dependencies; +} + +std::set parse_definition_dependencies( + LocalMessageDefinitionSource::Format format, + const std::string & text, + const std::string & package_context) +{ + switch (format) { + case LocalMessageDefinitionSource::Format::MSG: + return parse_msg_dependencies(text, package_context); + case LocalMessageDefinitionSource::Format::IDL: + return parse_idl_dependencies(text); + case LocalMessageDefinitionSource::Format::SRV: + { + auto dep = parse_msg_dependencies(text, package_context); + if (!dep.empty()) { + return dep; + } else { + return parse_idl_dependencies(text); + } + } + default: + throw std::runtime_error("switch is not exhaustive"); + } +} + +static const char * extension_for_format(LocalMessageDefinitionSource::Format format) +{ + switch (format) { + case LocalMessageDefinitionSource::Format::MSG: + return ".msg"; + case LocalMessageDefinitionSource::Format::IDL: + return ".idl"; + case LocalMessageDefinitionSource::Format::SRV: + return ".srv"; + default: + throw std::runtime_error("switch is not exhaustive"); + } +} + +std::string LocalMessageDefinitionSource::delimiter( + const DefinitionIdentifier & definition_identifier) +{ + std::string result = + "================================================================================\n"; + switch (definition_identifier.format()) { + case Format::MSG: + result += "MSG: "; + break; + case Format::IDL: + result += "IDL: "; + break; + case Format::SRV: + result += "SRV: "; + break; + default: + throw std::runtime_error("switch is not exhaustive"); + } + result += definition_identifier.topic_type(); + result += "\n"; + return result; +} + +LocalMessageDefinitionSource::MessageSpec::MessageSpec( + Format format, std::string text, + const std::string & package_context) +: dependencies(parse_definition_dependencies(format, text, package_context)) + , text(std::move(text)) + , format(format) +{ +} + +const LocalMessageDefinitionSource::MessageSpec & LocalMessageDefinitionSource::load_message_spec( + const DefinitionIdentifier & definition_identifier) +{ + if (auto it = msg_specs_by_definition_identifier_.find(definition_identifier); + it != msg_specs_by_definition_identifier_.end()) + { + return it->second; + } + std::smatch match; + const auto topic_type = definition_identifier.topic_type(); + if (!std::regex_match(topic_type, match, PACKAGE_TYPENAME_REGEX)) { + throw TypenameNotUnderstoodError(topic_type); + } + std::string package = match[1]; + std::string share_dir; + try { + share_dir = ament_index_cpp::get_package_share_directory(package); + } catch (const ament_index_cpp::PackageNotFoundError & e) { + ROSBAG2_CPP_LOG_WARN("'%s'", e.what()); + throw DefinitionNotFoundError(definition_identifier.topic_type()); + } + std::string dir = definition_identifier.format() == Format::MSG || + definition_identifier.format() == Format::IDL ? "/msg/" : "/srv/"; + std::ifstream file{share_dir + dir + match[2].str() + + extension_for_format(definition_identifier.format())}; + if (!file.good()) { + throw DefinitionNotFoundError(definition_identifier.topic_type()); + } + + std::string contents{std::istreambuf_iterator(file), {}}; + const MessageSpec & spec = msg_specs_by_definition_identifier_.emplace( + definition_identifier, + MessageSpec(definition_identifier.format(), std::move(contents), package)).first->second; + + // "References and pointers to data stored in the container are only invalidated by erasing that + // element, even when the corresponding iterator is invalidated." + return spec; +} + +rosbag2_storage::MessageDefinition LocalMessageDefinitionSource::get_full_text( + const std::string & root_type) +{ + std::unordered_set seen_deps; + + std::function append_recursive = + [&](const DefinitionIdentifier & definition_identifier, int32_t depth) { + if (depth <= 0) { + throw std::runtime_error{ + "Reached max recursion depth resolving definition of " + root_type}; + } + const MessageSpec & spec = load_message_spec(definition_identifier); + std::string result = spec.text; + for (const auto & dep_name : spec.dependencies) { + DefinitionIdentifier dep(dep_name, definition_identifier.format()); + bool inserted = seen_deps.insert(dep).second; + if (inserted) { + result += "\n"; + result += delimiter(dep); + result += append_recursive(dep, depth - 1); + } + } + return result; + }; + + std::string result; + Format format = Format::UNKNOWN; + int32_t max_recursion_depth = ROSBAG2_CPP_LOCAL_MESSAGE_DEFINITION_SOURCE_MAX_RECURSION_DEPTH; + + if (root_type.find("/srv/") == std::string::npos) { // Not a service + try { + format = Format::MSG; + result = append_recursive(DefinitionIdentifier(root_type, format), max_recursion_depth); + } catch (const DefinitionNotFoundError & err) { + ROSBAG2_CPP_LOG_WARN("No .msg definition for %s, falling back to IDL", err.what()); + format = Format::IDL; + DefinitionIdentifier root_definition_identifier(root_type, format); + result = (delimiter(root_definition_identifier) + + append_recursive(root_definition_identifier, max_recursion_depth)); + } catch (const TypenameNotUnderstoodError & err) { + ROSBAG2_CPP_LOG_ERROR( + "Message type name '%s' not understood by type definition search, " + "definition will be left empty in bag.", err.what()); + format = Format::UNKNOWN; + } + } else { + // The service dependencies could be either in the msg or idl files. Therefore, will try to + // search service dependencies in MSG files first then in IDL files via two separate recursive + // searches for each dependency. + format = Format::UNKNOWN; + DefinitionIdentifier def_identifier{root_type, Format::SRV}; + (void)seen_deps.insert(def_identifier).second; + result = delimiter(def_identifier); + const MessageSpec & spec = load_message_spec(def_identifier); + result += spec.text; + for (const auto & dep_name : spec.dependencies) { + DefinitionIdentifier dep(dep_name, Format::MSG); + bool inserted = seen_deps.insert(dep).second; + if (inserted) { + try { + result += "\n"; + result += delimiter(dep); + result += append_recursive(dep, max_recursion_depth); + format = Format::MSG; + } catch (const DefinitionNotFoundError & err) { + ROSBAG2_CPP_LOG_WARN("No .msg definition for %s, falling back to IDL", err.what()); + dep = DefinitionIdentifier(dep_name, Format::IDL); + inserted = seen_deps.insert(dep).second; + if (inserted) { + result += "\n"; + result += delimiter(dep); + result += append_recursive(dep, max_recursion_depth); + format = Format::IDL; + } + } catch (const TypenameNotUnderstoodError & err) { + ROSBAG2_CPP_LOG_ERROR( + "Message type name '%s' not understood by type definition search, " + "definition will be left empty in bag.", err.what()); + format = Format::UNKNOWN; + } + } + } + } + rosbag2_storage::MessageDefinition out; + switch (format) { + case Format::UNKNOWN: + out.encoding = "unknown"; + break; + case Format::MSG: + case Format::SRV: + out.encoding = "ros2msg"; + break; + case Format::IDL: + out.encoding = "ros2idl"; + break; + default: + throw std::runtime_error("switch is not exhaustive"); + } + + out.encoded_message_definition = result; + out.topic_type = root_type; + return out; +} +} // namespace rosbag2_cpp diff --git a/rosbag2_cpp/test/rosbag2_cpp/test_local_message_definition_source.cpp b/rosbag2_cpp/test/rosbag2_cpp/test_local_message_definition_source.cpp new file mode 100644 index 0000000000..710f99b225 --- /dev/null +++ b/rosbag2_cpp/test/rosbag2_cpp/test_local_message_definition_source.cpp @@ -0,0 +1,198 @@ +// Copyright 2022, Foxglove Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gmock/gmock.h" +#include "rosbag2_cpp/message_definitions/local_message_definition_source.hpp" + +using rosbag2_cpp::LocalMessageDefinitionSource; +using rosbag2_cpp::parse_definition_dependencies; +using ::testing::UnorderedElementsAre; + +TEST(test_local_message_definition_source, can_find_idl_includes) +{ + const char sample[] = + R"r( +#include "rosbag2_test_msgdefs/msg/BasicIdlA.idl" + +#include + +module rosbag2_test_msgdefs { + module msg { + struct ComplexIdl { + rosbag2_test_msgdefs::msg::BasicIdlA a; + rosbag2_test_msgdefs::msg::BasicIdlB b; + }; + }; +}; + + )r"; + std::set dependencies = parse_definition_dependencies( + LocalMessageDefinitionSource::Format::IDL, sample, ""); + ASSERT_THAT( + dependencies, UnorderedElementsAre( + "rosbag2_test_msgdefs/msg/BasicIdlA", + "rosbag2_test_msgdefs/msg/BasicIdlB")); +} + +TEST(test_local_message_definition_source, can_find_msg_deps) +{ + LocalMessageDefinitionSource source; + auto result = source.get_full_text("rosbag2_test_msgdefs/ComplexMsg"); + ASSERT_EQ(result.encoding, "ros2msg"); + ASSERT_EQ( + result.encoded_message_definition, + "rosbag2_test_msgdefs/BasicMsg b\n" + "\n" + "================================================================================\n" + "MSG: rosbag2_test_msgdefs/BasicMsg\n" + "float32 c\n"); +} + +TEST(test_local_message_definition_source, can_find_srv_deps_in_msg) +{ + LocalMessageDefinitionSource source; + auto result = source.get_full_text("rosbag2_test_msgdefs/srv/ComplexSrvMsg"); + ASSERT_EQ(result.encoding, "ros2msg"); + ASSERT_EQ( + result.encoded_message_definition, + "================================================================================\n" + "SRV: rosbag2_test_msgdefs/srv/ComplexSrvMsg\n" + "rosbag2_test_msgdefs/BasicMsg req\n" + "---\n" + "rosbag2_test_msgdefs/BasicMsg resp\n" + "\n" + "================================================================================\n" + "MSG: rosbag2_test_msgdefs/BasicMsg\n" + "float32 c\n") << result.encoded_message_definition << std::endl; +} + +TEST(test_local_message_definition_source, can_find_srv_deps_in_idl) +{ + LocalMessageDefinitionSource source; + auto result = source.get_full_text("rosbag2_test_msgdefs/srv/ComplexSrvIdl"); + ASSERT_EQ(result.encoding, "ros2idl"); + ASSERT_EQ( + result.encoded_message_definition, + "================================================================================\n" + "SRV: rosbag2_test_msgdefs/srv/ComplexSrvIdl\n" + "rosbag2_test_msgdefs/BasicIdl req\n" + "---\n" + "rosbag2_test_msgdefs/BasicIdl resp\n" + "\n" + "================================================================================\n" + "MSG: rosbag2_test_msgdefs/BasicIdl\n" + "\n" + "================================================================================\n" + "IDL: rosbag2_test_msgdefs/BasicIdl\n" + "module rosbag2_test_msgdefs {\n" + " module msg {\n" + " struct BasicIdl {\n" + " float x;\n" + " };\n" + " };\n" + "};\n") << result.encoded_message_definition << std::endl; +} + +TEST(test_local_message_definition_source, can_find_idl_deps) +{ + LocalMessageDefinitionSource source; + auto result = source.get_full_text("rosbag2_test_msgdefs/msg/ComplexIdl"); + ASSERT_EQ(result.encoding, "ros2idl"); + ASSERT_EQ( + result.encoded_message_definition, + "================================================================================\n" + "IDL: rosbag2_test_msgdefs/msg/ComplexIdl\n" + "#include \"rosbag2_test_msgdefs/msg/BasicIdl.idl\"\n" + "\n" + "module rosbag2_test_msgdefs {\n" + " module msg {\n" + " struct ComplexIdl {\n" + " rosbag2_test_msgdefs::msg::BasicIdl a;\n" + " };\n" + " };\n" + "};\n" + "\n" + "================================================================================\n" + "IDL: rosbag2_test_msgdefs/msg/BasicIdl\n" + "module rosbag2_test_msgdefs {\n" + " module msg {\n" + " struct BasicIdl {\n" + " float x;\n" + " };\n" + " };\n" + "};\n"); +} + +TEST(test_local_message_definition_source, can_resolve_msg_with_idl_deps) +{ + LocalMessageDefinitionSource source; + auto result = source.get_full_text("rosbag2_test_msgdefs/msg/ComplexMsgDependsOnIdl"); + ASSERT_EQ(result.encoding, "ros2idl"); + ASSERT_EQ( + result.encoded_message_definition, + "================================================================================\n" + "IDL: rosbag2_test_msgdefs/msg/ComplexMsgDependsOnIdl\n" + "// generated from rosidl_adapter/resource/msg.idl.em\n" + "// with input from rosbag2_test_msgdefs/msg/ComplexMsgDependsOnIdl.msg\n" + "// generated code does not contain a copyright notice\n" + "\n" + "#include \"rosbag2_test_msgdefs/msg/BasicIdl.idl\"\n" + "\n" + "module rosbag2_test_msgdefs {\n" + " module msg {\n" + " struct ComplexMsgDependsOnIdl {\n" + " rosbag2_test_msgdefs::msg::BasicIdl a;\n" + " };\n" + " };\n" + "};\n" + "\n" + "================================================================================\n" + "IDL: rosbag2_test_msgdefs/msg/BasicIdl\n" + "module rosbag2_test_msgdefs {\n" + " module msg {\n" + " struct BasicIdl {\n" + " float x;\n" + " };\n" + " };\n" + "};\n"); +} + +TEST(test_local_message_definition_source, no_crash_on_bad_name) +{ + LocalMessageDefinitionSource source; + rosbag2_storage::MessageDefinition result; + ASSERT_NO_THROW( + { + result = source.get_full_text("rosbag2_test_msgdefs/idl/BasicSrv_Request"); + }); + ASSERT_EQ(result.encoding, "unknown"); +} + +TEST(test_local_message_definition_source, throw_definition_not_found_for_unknown_msg) +{ + LocalMessageDefinitionSource source; + ASSERT_THROW( + { + source.get_full_text("rosbag2_test_msgdefs/msg/UnknownMessage"); + }, rosbag2_cpp::DefinitionNotFoundError); + + // Throw DefinitionNotFoundError for not found message definition package name + ASSERT_THROW( + { + source.get_full_text("not_found_msgdefs_pkg/msg/UnknownMessage"); + }, rosbag2_cpp::DefinitionNotFoundError); +} diff --git a/rosbag2_transport/src/rosbag2_transport/bag_rewrite.cpp b/rosbag2_transport/src/rosbag2_transport/bag_rewrite.cpp index 2ef591131d..70962699b5 100644 --- a/rosbag2_transport/src/rosbag2_transport/bag_rewrite.cpp +++ b/rosbag2_transport/src/rosbag2_transport/bag_rewrite.cpp @@ -103,7 +103,7 @@ setup_topic_filtering( } for (const auto & [writer, record_options] : output_bags) { - rosbag2_transport::TopicFilter topic_filter{record_options}; + rosbag2_transport::TopicFilter topic_filter{record_options, nullptr, true}; auto filtered_topics_and_types = topic_filter.filter_topics(input_topics); // Done filtering - set up writer