Skip to content

Commit

Permalink
reformat module registry printing a
Browse files Browse the repository at this point in the history
nd add utest
  • Loading branch information
Schmluk committed Apr 25, 2024
1 parent 2201be0 commit 2bb8d9f
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 15 deletions.
46 changes: 31 additions & 15 deletions config_utilities/include/config_utilities/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
Expand All @@ -58,23 +59,34 @@ namespace internal {

class ModuleRegistry {
public:
static void addModule(const std::string& type, const std::string& type_info) {
instance().modules.emplace_back(type, type_info);
template <typename BaseT, typename... Args>
static void addModule(const std::string& type) {
const std::string base_type = typeName<BaseT>();
std::stringstream ss;
((ss << typeName<Args>() << ", "), ...);
std::string arguments = ss.str();
if (!arguments.empty()) {
arguments = arguments.substr(0, arguments.size() - 2);
}

auto& types = instance().modules[base_type][arguments];
types.push_back(type);
std::sort(types.begin(), types.end());
}

static std::string getAllRegistered() {
std::stringstream ss;
ss << "Modules registered to factories: {";
auto modules = instance().modules;
std::sort(modules.begin(), modules.end());
if (!modules.empty()) {
ss << "\n";
}

for (auto&& [type, type_info] : modules) {
ss << " " << type << ": \"" << type_info << "\",\n";
for (auto&& [base_type, args] : instance().modules) {
for (auto&& [arguments, types] : args) {
ss << "\n " << base_type << "(" << arguments << "): {";
for (auto&& type : types) {
ss << "\n '" << type << "', ";
}
ss << "\n },";
}
}
ss << "}";
ss << "\n}";
return ss.str();
}

Expand All @@ -86,7 +98,8 @@ class ModuleRegistry {

ModuleRegistry() = default;

std::vector<std::pair<std::string, std::string>> modules;
// Nested modules: base_type -> args -> registered types.
std::map<std::string, std::map<std::string, std::vector<std::string>>> modules;
};

// Struct to store the factory methods for the creation of modules.
Expand All @@ -105,7 +118,6 @@ struct ModuleMapBase {
}

map.insert(std::make_pair(type, method));
ModuleRegistry::addModule(type, type_info);
return true;
}

Expand Down Expand Up @@ -249,7 +261,9 @@ struct ObjectFactory {
template <typename DerivedT>
static void addEntry(const std::string& type) {
FactoryMethod method = [](Args... args) { return new DerivedT(args...); };
ModuleMap::addEntry(type, method, typeInfo<BaseT, Args...>());
if (ModuleMap::addEntry(type, method, typeInfo<BaseT, Args...>())) {
ModuleRegistry::addModule<BaseT, Args...>(type);
}
}

static std::unique_ptr<BaseT> create(const std::string& type, Args... args) {
Expand All @@ -276,7 +290,9 @@ struct ObjectWithConfigFactory {
Visitor::setValues(config, data);
return new DerivedT(config, args...);
};
ModuleMap::addEntry(type, method, typeInfo<BaseT, Args...>());
if (ModuleMap::addEntry(type, method, typeInfo<BaseT, Args...>())) {
ModuleRegistry::addModule<BaseT, Args...>(type);
}
}

static std::unique_ptr<BaseT> create(const YAML::Node& data, Args... args) {
Expand Down
2 changes: 2 additions & 0 deletions config_utilities/test/include/config_utilities/test/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class TestLogger : public internal::Logger {
int numMessages() const { return messages_.size(); }
void clear() { messages_.clear(); }
void print() const;
bool hasMessages() const { return !messages_.empty(); }
const std::string& lastMessage() const { return messages_.back().second; }

static std::shared_ptr<TestLogger> create();

Expand Down
64 changes: 64 additions & 0 deletions config_utilities/test/tests/factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ void declare_config(DerivedD::Config& config) {
config::field(config.i, "i");
}

template <typename T>
struct TemplatedBase {
virtual ~TemplatedBase() = default;
};

template <typename DerivedT, typename BaseT>
struct TemplatedDerived : public TemplatedBase<BaseT> {};

TEST(Factory, create) {
std::unique_ptr<Base> base = create<Base>("DerivedA", 1);
EXPECT_TRUE(base);
Expand Down Expand Up @@ -164,4 +172,60 @@ TEST(Factory, createWithConfig) {
EXPECT_EQ(dynamic_cast<DerivedD*>(base.get())->config_.i, 3);
}

TEST(Factory, moduleNameConflicts) {
auto logger = TestLogger::create();

// Allow shadowing of same name for different module types.
const auto registration1 = config::Registration<TemplatedBase<int>, TemplatedDerived<int, int>>("name");
const auto registration2 = config::Registration<TemplatedBase<float>, TemplatedDerived<float, float>>("name");
EXPECT_EQ(logger->numMessages(), 0);

// Same derived different name. NOTE(lschmid): This is allowed, not sure if we would want to warn users though.
const auto registration3 = config::Registration<TemplatedBase<int>, TemplatedDerived<int, int>>("other_name");
EXPECT_FALSE(logger->hasMessages());

// Same derived same name. Not allowed. NOTE(lschmid): Could also be an option to make this allowed (skip silently).
const auto registration4 = config::Registration<TemplatedBase<int>, TemplatedDerived<int, int>>("name");
EXPECT_EQ(logger->numMessages(), 1);
EXPECT_EQ(logger->lastMessage(),
"Cannot register already existent type 'name' for BaseT='config::test::TemplatedBase<int>' and "
"ConstructorArguments={}.");

// Different derived same base and same name. Not allowed.
const auto registration5 = config::Registration<TemplatedBase<int>, TemplatedDerived<float, int>>("name");
EXPECT_EQ(logger->numMessages(), 2);
EXPECT_EQ(logger->lastMessage(),
"Cannot register already existent type 'name' for BaseT='config::test::TemplatedBase<int>' and "
"ConstructorArguments={}.");
}

TEST(Factory, printRegistryInfo) {
const std::string expected = R"""(Modules registered to factories: {
config::internal::Formatter(): {
'asl',
},
config::test::Base(int): {
'DerivedA',
'DerivedB',
'DerivedC',
'DerivedD',
},
config::test::Base2(): {
'Derived2',
},
config::test::ProcessorBase(): {
'AddString',
},
config::test::TemplatedBase<float>(): {
'name',
},
config::test::TemplatedBase<int>(): {
'name',
'other_name',
},
})""";
const std::string modules = internal::ModuleRegistry::getAllRegistered();
EXPECT_EQ(modules, expected);
}

} // namespace config::test

0 comments on commit 2bb8d9f

Please sign in to comment.