Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use hipo::banklist for mutation #32

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace iguana::clas12 {

}

void EventBuilderFilter::Start(bank_index_cache_t &index_cache) {
void EventBuilderFilter::Start(bank_index_cache_t& index_cache) {

// define options, their default values, and cache them
CacheOption("pids", std::set<int>{11, 211}, o_pids);
Expand All @@ -23,18 +23,18 @@ namespace iguana::clas12 {
}


void EventBuilderFilter::Run(bank_vec_t banks) {
void EventBuilderFilter::Run(hipo::banklist& banks) {

// get the banks
auto particleBank = GetBank(banks, b_particle, "REC::Particle");
auto caloBank = GetBank(banks, b_calo, "REC::Calorimeter"); // TODO: remove
auto& particleBank = GetBank(banks, b_particle, "REC::Particle");
// auto& caloBank = GetBank(banks, b_calo, "REC::Calorimeter"); // TODO: remove

// dump the bank
ShowBank(particleBank, Logger::Header("INPUT PARTICLES"));

// filter the input bank for requested PDG code(s)
for(int row = 0; row < particleBank->getRows(); row++) {
auto pid = particleBank->getInt("pid", row);
for(int row = 0; row < particleBank.getRows(); row++) {
auto pid = particleBank.getInt("pid", row);
auto accept = Filter(pid);
if(!accept)
MaskRow(particleBank, row);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ namespace iguana::clas12 {
~EventBuilderFilter() {}

void Start() override { Algorithm::Start(); }
void Start(bank_index_cache_t &index_cache) override;
void Run(bank_vec_t banks) override;
void Start(bank_index_cache_t& index_cache) override;
void Run(hipo::banklist& banks) override;
void Stop() override;

bool Filter(int pid);

private:

/// `bank_vec_t` indices
/// `hipo::banklist` indices
int b_particle, b_calo; // TODO: remove calorimeter

/// configuration options
Expand Down
2 changes: 1 addition & 1 deletion src/iguana/Iguana.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace iguana {

Iguana::Iguana() {
algo_map.insert({clas12_EventBuilderFilter, std::make_shared<clas12::EventBuilderFilter>()});
algo_map.insert({clas12_EventBuilderFilter, std::move(std::make_unique<clas12::EventBuilderFilter>())});
}

}
2 changes: 1 addition & 1 deletion src/iguana/Iguana.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace iguana {
};

// TODO: make private
std::unordered_map<Iguana::algo, std::shared_ptr<Algorithm>> algo_map;
std::unordered_map<Iguana::algo, std::unique_ptr<Algorithm>> algo_map;

};
}
43 changes: 24 additions & 19 deletions src/services/Algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace iguana {

Algorithm::Algorithm(std::string name) : m_name(name) {
m_log = std::make_shared<Logger>(m_name);
m_log = std::make_unique<Logger>(m_name);
}

void Algorithm::Start() {
Expand All @@ -19,14 +19,14 @@ namespace iguana {
m_log->Debug("User set option '{}' = {}", key, PrintOptionValue(key));
}

std::shared_ptr<Logger> Algorithm::Log() {
std::unique_ptr<Logger>& Algorithm::Log() {
return m_log;
}

void Algorithm::CacheBankIndex(bank_index_cache_t index_cache, int &idx, std::string bankName) {
void Algorithm::CacheBankIndex(bank_index_cache_t index_cache, int& idx, std::string bankName) {
try {
idx = index_cache.at(bankName);
} catch(const std::out_of_range &o) {
} catch(const std::out_of_range& o) {
Throw(fmt::format("required input bank '{}' not found; cannot `Start` algorithm '{}'", bankName, m_name));
}
m_log->Debug("cached index of bank '{}' is {}", bankName, idx);
Expand All @@ -50,39 +50,44 @@ namespace iguana {
return "UNKNOWN";
}

bank_ptr Algorithm::GetBank(bank_vec_t banks, int idx, std::string expectedBankName) {
bank_ptr result;
hipo::bank& Algorithm::GetBank(hipo::banklist& banks, int idx, std::string expectedBankName) {
try {
result = banks.at(idx);
} catch(const std::out_of_range &o) {
auto& result = banks.at(idx);
if(expectedBankName != "" && result.getSchema().getName() != expectedBankName) {
Throw(fmt::format("expected input bank '{}' at index={}; got bank named '{}'", expectedBankName, idx, result.getSchema().getName()));
}
return result;
} catch(const std::out_of_range& o) {
Throw(fmt::format("required input bank '{}' not found; cannot `Run` algorithm '{}'", expectedBankName, m_name));
}
if(expectedBankName != "" && result->getSchema().getName() != expectedBankName) {
Throw(fmt::format("expected input bank '{}' at index={}; got bank named '{}'", expectedBankName, idx, result->getSchema().getName()));
}
return result;
throw std::runtime_error("GetBank failed"); // avoid `-Wreturn-type` warning
}

void Algorithm::MaskRow(bank_ptr bank, int row) {
void Algorithm::MaskRow(hipo::bank& bank, int row) {
// TODO: need https://github.com/gavalian/hipo/issues/35
// until then, just set the PID to -1
bank->putInt("pid", row, -1);
bank.putInt("pid", row, -1);
}

void Algorithm::ShowBanks(bank_vec_t banks, std::string message, Logger::Level level) {
void Algorithm::ShowBanks(hipo::banklist& banks, std::string message, Logger::Level level) {
if(m_log->GetLevel() <= level) {
if(message != "")
m_log->Print(level, message);
for(auto bank : banks)
bank->show();
for(auto& bank : banks)
bank.show();
}
}

void Algorithm::ShowBank(bank_ptr bank, std::string message, Logger::Level level) {
//
// FIXME: const protection should be applied everywhere
// it's possbile, to indicate immuatibility
//
//
c-dilks marked this conversation as resolved.
Show resolved Hide resolved
void Algorithm::ShowBank(hipo::bank& bank, std::string message, Logger::Level level) {
if(m_log->GetLevel() <= level) {
if(message != "")
m_log->Print(level, message);
bank->show();
bank.show();
}
}

Expand Down
38 changes: 19 additions & 19 deletions src/services/Algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ namespace iguana {

/// Initialize an algorithm before any events are processed
/// @param index_cache The `Run` method will use these indices to access banks
virtual void Start(bank_index_cache_t &index_cache) = 0;
virtual void Start(bank_index_cache_t& index_cache) = 0;

/// Run an algorithm
/// @param banks the set of banks to process
virtual void Run(bank_vec_t banks) = 0;
virtual void Run(hipo::banklist& banks) = 0;

/// Finalize an algorithm after all events are processed
virtual void Stop() = 0;
Expand All @@ -38,27 +38,27 @@ namespace iguana {

/// Get the logger
/// @return the logger used by this algorithm
std::shared_ptr<Logger> Log();
std::unique_ptr<Logger>& Log();

protected:

/// Cache the index of a bank in a `bank_vec_t`; throws an exception if the bank is not found
/// @param index_cache the relation between bank name and `bank_vec_t` index
/// @param idx a reference to the `bank_vec_t` index of the bank
/// Cache the index of a bank in a `hipo::banklist`; throws an exception if the bank is not found
/// @param index_cache the relation between bank name and `hipo::banklist` index
/// @param idx a reference to the `hipo::banklist` index of the bank
/// @param bankName the name of the bank
void CacheBankIndex(bank_index_cache_t index_cache, int &idx, std::string bankName);
void CacheBankIndex(bank_index_cache_t index_cache, int& idx, std::string bankName) noexcept(false);

/// Cache an option specified by the user, and define its default value
/// @param key the name of the option
/// @param def the default value
/// @param val reference to the value of the option, to be cached by `Start`
template <typename OPTION_TYPE>
void CacheOption(std::string key, OPTION_TYPE def, OPTION_TYPE &val) {
void CacheOption(std::string key, OPTION_TYPE def, OPTION_TYPE& val) {
bool get_error = false;
if(auto it{m_opt.find(key)}; it != m_opt.end()) { // cache the user's option value
try { // get the expected type
val = std::get<OPTION_TYPE>(it->second);
} catch(const std::bad_variant_access &ex1) {
} catch(const std::bad_variant_access& ex1) {
m_log->Error("user option '{}' set to '{}', which is the wrong type...", key, PrintOptionValue(key));
get_error = true;
val = def;
Expand All @@ -79,33 +79,33 @@ namespace iguana {
/// @return the string value and its type
std::string PrintOptionValue(std::string key);

/// Get the pointer to a bank from a `bank_vec_t`; optionally checks if the bank name matches the expectation
/// @param banks the `bank_vec_t` from which to get the specified bank
/// Get the pointer to a bank from a `hipo::banklist`; optionally checks if the bank name matches the expectation
/// @param banks the `hipo::banklist` from which to get the specified bank
/// @param idx the index of `banks` of the specified bank
/// @param expectedBankName if specified, checks that the specified bank has this name
/// @return the modified `bank_vec_t`
bank_ptr GetBank(bank_vec_t banks, int idx, std::string expectedBankName="");
/// @return the modified `hipo::banklist`
hipo::bank& GetBank(hipo::banklist& banks, int idx, std::string expectedBankName="") noexcept(false);

/// Mask a row, setting all items to zero
/// @param bank the bank to modify
/// @param row the row to blank
void MaskRow(bank_ptr bank, int row);
void MaskRow(hipo::bank& bank, int row);

/// Dump all banks in a `bank_vec_t`
/// Dump all banks in a `hipo::banklist`
/// @param banks the banks to show
/// @param message optionally print a header message
/// @param level the log level
void ShowBanks(bank_vec_t banks, std::string message="", Logger::Level level=Logger::trace);
void ShowBanks(hipo::banklist& banks, std::string message="", Logger::Level level=Logger::trace);

/// Dump a single bank
/// @param bank the bank to show
/// @param message optionally print a header message
/// @param level the log level
void ShowBank(bank_ptr bank, std::string message="", Logger::Level level=Logger::trace);
void ShowBank(hipo::bank& bank, std::string message="", Logger::Level level=Logger::trace);

/// Stop the algorithm and throw a runtime exception
/// @param message the error message
void Throw(std::string message);
void Throw(std::string message) noexcept(false);

/// algorithm name
std::string m_name;
Expand All @@ -114,7 +114,7 @@ namespace iguana {
std::vector<std::string> m_requiredBanks;

/// Logger
std::shared_ptr<Logger> m_log;
std::unique_ptr<Logger> m_log;

/// Configuration options
options_t m_opt;
Expand Down
2 changes: 1 addition & 1 deletion src/services/Logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace iguana {
}

void Logger::SetLevel(std::string lev) {
for(auto &[lev_i, lev_n] : m_level_names) {
for(auto& [lev_i, lev_n] : m_level_names) {
if(lev == lev_n) {
SetLevel(lev_i);
return;
Expand Down
8 changes: 1 addition & 7 deletions src/services/TypeDefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@

namespace iguana {

/// pointer to a HIPO bank
using bank_ptr = std::shared_ptr<hipo::bank>;

/// ordered list of HIPO bank pointers
using bank_vec_t = std::vector<bank_ptr>;

/// association between HIPO bank name and its index in a `bank_vec_t`
/// association between HIPO bank name and its index in a `hipo::banklist`
using bank_index_cache_t = std::unordered_map<std::string, int>;

/// option value variant type
Expand Down
40 changes: 19 additions & 21 deletions src/tests/run_banks.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "iguana/Iguana.h"
#include <hipo4/reader.h>

void printParticles(std::string prefix, iguana::bank_ptr b) {
void printParticles(std::string prefix, hipo::bank& b) {
std::vector<int> pids;
for(int row=0; row<b->getRows(); row++)
pids.push_back(b->getInt("pid", row));
for(int row=0; row<b.getRows(); row++)
pids.push_back(b.getInt("pid", row));
fmt::print("{}: {}\n", prefix, fmt::join(pids, ", "));
}

Expand All @@ -20,7 +20,7 @@ int main(int argc, char **argv) {
* use the test algorithm directly
*/
iguana::Iguana I;
auto algo = I.algo_map.at(iguana::Iguana::clas12_EventBuilderFilter);
auto& algo = I.algo_map.at(iguana::Iguana::clas12_EventBuilderFilter);
algo->Log()->SetLevel("trace");
// algo->Log()->DisableStyle();
algo->SetOption("pids", std::set<int>{11, 211, -211});
Expand All @@ -31,26 +31,24 @@ int main(int argc, char **argv) {
/////////////////////////////////////////////////////

// read input file
hipo::reader reader;
reader.open(inFileName.c_str());

// get bank schema
/* TODO: users should not have to do this; this is a workaround until
* the pattern `hipo::event::getBank("REC::Particle")` is possible
*/
hipo::dictionary factory;
reader.readDictionary(factory);
auto particleBank = std::make_shared<hipo::bank>(factory.getSchema("REC::Particle"));
auto caloBank = std::make_shared<hipo::bank>(factory.getSchema("REC::Calorimeter")); // TODO: remove when not needed (this is for testing)
hipo::reader reader(inFileName.c_str());

// set banks
hipo::banklist banks = reader.getBanks({
"REC::Particle",
"REC::Calorimeter"
});
enum banks_enum { // TODO: make this nicer
b_particle,
b_calo
};

// event loop
hipo::event event;
int iEvent = 0;
while(reader.next(event) && (iEvent++ < numEvents || numEvents == 0)) {
event.getStructure(*particleBank);
printParticles("PIDS BEFORE algo->Run() ", particleBank);
algo->Run({particleBank, caloBank});
printParticles("PIDS AFTER algo->Run() ", particleBank);
while(reader.next(banks) && (iEvent++ < numEvents || numEvents == 0)) {
printParticles("PIDS BEFORE algo->Run() ", banks.at(b_particle));
algo->Run(banks);
printParticles("PIDS AFTER algo->Run() ", banks.at(b_particle));
}

/////////////////////////////////////////////////////
Expand Down
7 changes: 4 additions & 3 deletions src/tests/run_rows.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

int main(int argc, char **argv) {

/* DISABLED until `run_banks` is more stable

// parse arguments
int argi = 1;
std::string inFileName = argc > argi ? std::string(argv[argi++]) : "data.hipo";
int numEvents = argc > argi ? std::stoi(argv[argi++]) : 3;


// start the algorithm
auto algo = std::make_shared<iguana::clas12::EventBuilderFilter>();
algo->SetOption("pids", std::set<int>{11, 211, -211});
Expand All @@ -20,9 +23,6 @@ int main(int argc, char **argv) {
reader.open(inFileName.c_str());

// get bank schema
/* TODO: users should not have to do this; this is a workaround until
* the pattern `hipo::event::getBank("REC::Particle")` is possible
*/
hipo::dictionary factory;
reader.readDictionary(factory);
auto particleBank = std::make_shared<hipo::bank>(factory.getSchema("REC::Particle"));
Expand All @@ -43,5 +43,6 @@ int main(int argc, char **argv) {
/////////////////////////////////////////////////////

algo->Stop();
*/
return 0;
}