From 79505c84cd2909fa9c4a11874db34e27ecf3a687 Mon Sep 17 00:00:00 2001 From: Christopher Dilks Date: Tue, 21 Nov 2023 11:57:14 -0500 Subject: [PATCH] feat: add event builder filter algorithm (#9) --- meson.build | 5 +- .../EventBuilderFilter.cc | 70 +++++++++++++++++++ .../event_builder_filter/EventBuilderFilter.h | 30 ++++++++ .../clas12/event_builder_filter/README.md | 3 + .../{clas12/fiducial_cuts => }/meson.build | 4 +- src/iguana/Arbiter.cc | 2 +- src/iguana/Arbiter.h | 4 +- src/iguana/meson.build | 1 + src/services/Algorithm.cc | 50 ++++++++++++- src/services/Algorithm.h | 60 +++++++++++++++- src/services/Logger.cc | 4 ++ src/services/Logger.h | 31 ++++---- src/services/meson.build | 4 +- src/tests/main.cc | 22 +++++- 14 files changed, 257 insertions(+), 33 deletions(-) create mode 100644 src/algorithms/clas12/event_builder_filter/EventBuilderFilter.cc create mode 100644 src/algorithms/clas12/event_builder_filter/EventBuilderFilter.h create mode 100644 src/algorithms/clas12/event_builder_filter/README.md rename src/algorithms/{clas12/fiducial_cuts => }/meson.build (78%) diff --git a/meson.build b/meson.build index 8134c3b7..a1048224 100644 --- a/meson.build +++ b/meson.build @@ -2,7 +2,8 @@ project( 'iguana', 'cpp', version: '0.0.0', - license: 'LGPLv3' + license: 'LGPLv3', + default_options: [ 'cpp_std=c++20' ], ) project_inc = include_directories('src') @@ -20,6 +21,6 @@ fmt_dep = dependency('fmt') hipo_dep = dependency('hipo4', method: 'cmake', cmake_args: '-DCMAKE_PREFIX_PATH=' + get_option('hipo')) subdir('src/services') -subdir('src/algorithms/clas12/fiducial_cuts') +subdir('src/algorithms') subdir('src/iguana') subdir('src/tests') diff --git a/src/algorithms/clas12/event_builder_filter/EventBuilderFilter.cc b/src/algorithms/clas12/event_builder_filter/EventBuilderFilter.cc new file mode 100644 index 00000000..32c55b9d --- /dev/null +++ b/src/algorithms/clas12/event_builder_filter/EventBuilderFilter.cc @@ -0,0 +1,70 @@ +#include "EventBuilderFilter.h" + +namespace iguana::clas12 { + + void EventBuilderFilter::Start() { + m_log->Debug("START {}", m_name); + + // set configuration + m_log->SetLevel(Logger::Level::trace); + m_opt.mode = EventBuilderFilterOptions::Modes::blank; + m_opt.pids = {11, 211, -211}; + } + + + Algorithm::BankMap EventBuilderFilter::Run(Algorithm::BankMap inBanks) { + m_log->Debug("RUN {}", m_name); + + // check the input banks existence + if(MissingInputBanks(inBanks, {"particles"})) + Throw("missing input banks"); + + // define the output schemata and banks + BankMap outBanks = { + { "particles", hipo::bank(inBanks.at("particles").getSchema()) } + }; + + // filter the input bank for requested PDG code(s) + std::set acceptedRows; + for(int row = 0; row < inBanks.at("particles").getRows(); row++) { + auto pid = inBanks.at("particles").get("pid", row); + auto accept = m_opt.pids.contains(pid); + if(accept) acceptedRows.insert(row); + m_log->Debug("input PID {} -- accept = {}", pid, accept); + } + + // fill the output bank + switch(m_opt.mode) { + + case EventBuilderFilterOptions::Modes::blank: + outBanks.at("particles").setRows(inBanks.at("particles").getRows()); + for(int row = 0; row < inBanks.at("particles").getRows(); row++) { + if(acceptedRows.contains(row)) + CopyBankRow(inBanks.at("particles"), row, outBanks.at("particles"), row); + else + BlankRow(outBanks.at("particles"), row); + } + break; + + case EventBuilderFilterOptions::Modes::compact: + outBanks.at("particles").setRows(acceptedRows.size()); + for(int row = 0; auto acceptedRow : acceptedRows) + CopyBankRow(inBanks.at("particles"), acceptedRow, outBanks.at("particles"), row++); + break; + + default: + Throw("unknown 'mode' option"); + + } + + // dump the banks and return the output + ShowBanks(inBanks, outBanks); + return outBanks; + } + + + void EventBuilderFilter::Stop() { + m_log->Debug("STOP {}", m_name); + } + +} diff --git a/src/algorithms/clas12/event_builder_filter/EventBuilderFilter.h b/src/algorithms/clas12/event_builder_filter/EventBuilderFilter.h new file mode 100644 index 00000000..89b249b7 --- /dev/null +++ b/src/algorithms/clas12/event_builder_filter/EventBuilderFilter.h @@ -0,0 +1,30 @@ +#pragma once + +#include "services/Algorithm.h" + +namespace iguana::clas12 { + + class EventBuilderFilterOptions { + public: + enum Modes { blank, compact }; + Modes mode = blank; + std::set pids = {11, 211}; + }; + + + class EventBuilderFilter : public Algorithm { + + public: + EventBuilderFilter() : Algorithm("event_builder_filter") {} + ~EventBuilderFilter() {} + + void Start() override; + Algorithm::BankMap Run(Algorithm::BankMap inBanks) override; + void Stop() override; + + private: + EventBuilderFilterOptions m_opt; + + }; + +} diff --git a/src/algorithms/clas12/event_builder_filter/README.md b/src/algorithms/clas12/event_builder_filter/README.md new file mode 100644 index 00000000..79958d20 --- /dev/null +++ b/src/algorithms/clas12/event_builder_filter/README.md @@ -0,0 +1,3 @@ +# Event Builder Filter + +Filters a particle bank for specific Event Builder PDGs diff --git a/src/algorithms/clas12/fiducial_cuts/meson.build b/src/algorithms/meson.build similarity index 78% rename from src/algorithms/clas12/fiducial_cuts/meson.build rename to src/algorithms/meson.build index cc2767fe..fa9f2f32 100644 --- a/src/algorithms/clas12/fiducial_cuts/meson.build +++ b/src/algorithms/meson.build @@ -1,9 +1,9 @@ algo_headers = [ - 'FiducialCuts.h', + 'clas12/event_builder_filter/EventBuilderFilter.h', ] algo_sources = [ - 'FiducialCuts.cc', + 'clas12/event_builder_filter/EventBuilderFilter.cc', ] algo_lib = shared_library( diff --git a/src/iguana/Arbiter.cc b/src/iguana/Arbiter.cc index a2cf801c..e0e1f7dc 100644 --- a/src/iguana/Arbiter.cc +++ b/src/iguana/Arbiter.cc @@ -3,7 +3,7 @@ namespace iguana { Arbiter::Arbiter() { - algo_map.insert({clas12_FiducialCuts, std::make_shared()}); + algo_map.insert({clas12_EventBuilderFilter, std::make_shared()}); } } diff --git a/src/iguana/Arbiter.h b/src/iguana/Arbiter.h index 20ba893b..d3fd06ec 100644 --- a/src/iguana/Arbiter.h +++ b/src/iguana/Arbiter.h @@ -5,7 +5,7 @@ #include // TODO: avoid listing the algos -#include "algorithms/clas12/fiducial_cuts/FiducialCuts.h" +#include "algorithms/clas12/event_builder_filter/EventBuilderFilter.h" namespace iguana { @@ -18,7 +18,7 @@ namespace iguana { // TODO: avoid listing the algos // TODO: who should own the algorithm instances: Arbiter or the user? enum algo { - clas12_FiducialCuts + clas12_EventBuilderFilter }; // TODO: make private diff --git a/src/iguana/meson.build b/src/iguana/meson.build index 97ff1648..ffe284ff 100644 --- a/src/iguana/meson.build +++ b/src/iguana/meson.build @@ -10,6 +10,7 @@ iguana_lib = shared_library( 'Iguana', iguana_sources, include_directories: project_inc, + dependencies: [ fmt_dep, hipo_dep ], link_with: [ algo_lib, services_lib ], install: true, install_dir: project_lib_install_dir, diff --git a/src/services/Algorithm.cc b/src/services/Algorithm.cc index d6d69f5c..d5a82ede 100644 --- a/src/services/Algorithm.cc +++ b/src/services/Algorithm.cc @@ -2,8 +2,54 @@ namespace iguana { - Algorithm::Algorithm(std::string name) { - m_log = std::make_shared(name); + Algorithm::Algorithm(std::string name) : m_name(name) { + m_log = std::make_shared(m_name); + } + + bool Algorithm::MissingInputBanks(BankMap banks, std::set keys) { + for(auto key : keys) { + if(!banks.contains(key)) { + m_log->Error("Algorithm '{}' is missing the input bank '{}'", m_name, key); + m_log->Error(" => the following input banks are required by '{}':", m_name); + for(auto k : keys) + m_log->Error(" - {}", k); + return true; + } + } + return false; + } + + void Algorithm::CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow) { + // TODO: check srcBank.getSchema() == destBank.getSchema() + for(int item = 0; item < srcBank.getSchema().getEntries(); item++) { + auto val = srcBank.get(item, srcRow); + destBank.put(item, destRow, val); + } + } + + void Algorithm::BlankRow(hipo::bank bank, int row) { + for(int item = 0; item < bank.getSchema().getEntries(); item++) { + bank.put(item, row, 0); + } + } + + void Algorithm::ShowBanks(BankMap banks, std::string message, Logger::Level level) { + if(m_log->GetLevel() <= level) { + m_log->Print(level, message); + for(auto [key,bank] : banks) { + m_log->Print(level, "BANK: '{}'", key); + bank.show(); + } + } + } + + void Algorithm::ShowBanks(BankMap inBanks, BankMap outBanks, Logger::Level level) { + ShowBanks(inBanks, "===== INPUT BANKS =====", level); + ShowBanks(outBanks, "===== OUTPUT BANKS =====", level); + } + + void Algorithm::Throw(std::string message) { + throw std::runtime_error(fmt::format("CRITICAL ERROR: {}; Algorithm '{}' stopped!", message, m_name)); } } diff --git a/src/services/Algorithm.h b/src/services/Algorithm.h index d5cbf740..7b6bba97 100644 --- a/src/services/Algorithm.h +++ b/src/services/Algorithm.h @@ -1,19 +1,75 @@ #pragma once #include "Logger.h" +#include +#include namespace iguana { class Algorithm { public: + + using BankMap = std::unordered_map; + + /// Algorithm base class constructor + /// @param name the unique name for a derived class instance Algorithm(std::string name); + + /// Algorithm base class destructor + virtual ~Algorithm() {} + + /// Initialize an algorithm before any events are processed virtual void Start() = 0; - virtual int Run(int a, int b) = 0; + + /// Run an algorithm + /// @param inBanks the set of input banks + /// @return a set of output banks + virtual BankMap Run(BankMap inBanks) = 0; + + /// Finalize an algorithm after all events are processed virtual void Stop() = 0; - virtual ~Algorithm() {} protected: + + /// Check if `banks` contains all keys `keys`; this is useful for checking algorithm inputs are complete. + /// @param banks the set of (key,bank) pairs to check + /// @keys the required keys + /// @return true if `banks` is missing any keys in `keys` + bool MissingInputBanks(BankMap banks, std::set keys); + + /// Copy a row from one bank to another, assuming their schemata are equivalent + /// @param srcBank the source bank + /// @param srcRow the row in `srcBank` to copy from + /// @param destBank the destination bank + /// @param destRow the row in `destBank` to copy to + void CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow); + + /// Blank a row, setting all items to zero + /// @param bank the bank to modify + /// @param row the row to blank + void BlankRow(hipo::bank bank, int row); + + /// Dump all banks in a BankMap + /// @param banks the banks to show + /// @param message optionally print a header message + /// @param level the log level + void ShowBanks(BankMap banks, std::string message="", Logger::Level level=Logger::trace); + + /// Dump all input and output banks + /// @param inBanks the input banks + /// @param outBanks the output banks + /// @param level the log level + void ShowBanks(BankMap inBanks, BankMap outBanks, Logger::Level level=Logger::trace); + + /// Stop the algorithm and throw a runtime exception + /// @param message the error message + void Throw(std::string message); + + /// algorithm name + std::string m_name; + + /// Logger std::shared_ptr m_log; }; } diff --git a/src/services/Logger.cc b/src/services/Logger.cc index 42282c3b..9af7d518 100644 --- a/src/services/Logger.cc +++ b/src/services/Logger.cc @@ -18,4 +18,8 @@ namespace iguana { Debug("Logger '{}' set to '{}'", m_name, m_level_names.at(m_level)); } + Logger::Level Logger::GetLevel() { + return m_level; + } + } diff --git a/src/services/Logger.h b/src/services/Logger.h index d95fd53e..e64c63d3 100644 --- a/src/services/Logger.h +++ b/src/services/Logger.h @@ -23,30 +23,27 @@ namespace iguana { ~Logger() {} void SetLevel(Level lev); + Level GetLevel(); - template void Trace(std::string msg, VALUES... vals) { Print(trace, msg, vals...); } - template void Debug(std::string msg, VALUES... vals) { Print(debug, msg, vals...); } - template void Info(std::string msg, VALUES... vals) { Print(info, msg, vals...); } - template void Warn(std::string msg, VALUES... vals) { Print(warn, msg, vals...); } - template void Error(std::string msg, VALUES... vals) { Print(error, msg, vals...); } + template void Trace(std::string message, VALUES... vals) { Print(trace, message, vals...); } + template void Debug(std::string message, VALUES... vals) { Print(debug, message, vals...); } + template void Info(std::string message, VALUES... vals) { Print(info, message, vals...); } + template void Warn(std::string message, VALUES... vals) { Print(warn, message, vals...); } + template void Error(std::string message, VALUES... vals) { Print(error, message, vals...); } template - void Print(Level lev, std::string msg, VALUES... vals) { + void Print(Level lev, std::string message, VALUES... vals) { if(lev >= m_level) { - auto level_name_it = m_level_names.find(lev); - if(level_name_it == m_level_names.end()) { - Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt - Error(msg, vals...); - } else { + if(m_level_names.contains(lev)) { + auto prefix = fmt::format("[{}] [{}] ", m_level_names.at(lev), m_name); fmt::print( lev >= warn ? stderr : stdout, - fmt::format( - "[{}] [{}] {}\n", - level_name_it->second, - m_name, - fmt::format(msg, vals...) - ) + fmt::runtime(prefix + message + "\n"), + vals... ); + } else { + Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt + Error(message, vals...); } } } diff --git a/src/services/meson.build b/src/services/meson.build index c60ef1b4..96cfb9ab 100644 --- a/src/services/meson.build +++ b/src/services/meson.build @@ -12,14 +12,14 @@ services_lib = shared_library( 'IguanaServices', services_sources, include_directories: project_inc, - dependencies: fmt_dep, + dependencies: [ fmt_dep, hipo_dep ], install: true, install_dir: project_lib_install_dir, install_rpath: project_lib_rpath, ) services_dep = declare_dependency( - dependencies: fmt_dep + dependencies: [ fmt_dep, hipo_dep ] ) install_headers(services_headers, subdir : meson.project_name()) diff --git a/src/tests/main.cc b/src/tests/main.cc index f8b8763b..782e0a58 100644 --- a/src/tests/main.cc +++ b/src/tests/main.cc @@ -10,11 +10,27 @@ int main(int argc, char **argv) { reader.open(inFile.c_str()); hipo::dictionary factory; reader.readDictionary(factory); - factory.show(); + // factory.show(); + + hipo::bank particleBank(factory.getSchema("REC::Particle")); + hipo::event event; iguana::Arbiter arb; - auto algo = arb.algo_map.at(iguana::Arbiter::clas12_FiducialCuts); + auto algo = arb.algo_map.at(iguana::Arbiter::clas12_EventBuilderFilter); algo->Start(); - fmt::print("test result: {}\n", algo->Run(3,4)); + + int count = 0; + while(reader.next()) { + if(count > 3) break; + reader.read(event); + event.getStructure(particleBank); + + auto resultBank = algo->Run({{"particles", particleBank}}); + + fmt::print("BEFORE -> AFTER: {} -> {}\n", particleBank.getRows(), resultBank.at("particles").getRows()); + + count++; + } + algo->Stop(); }