Skip to content

Commit

Permalink
feat: add event builder filter algorithm (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-dilks committed Nov 21, 2023
1 parent 74d085c commit 79505c8
Show file tree
Hide file tree
Showing 14 changed files with 257 additions and 33 deletions.
5 changes: 3 additions & 2 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ project(
'iguana',
'cpp',
version: '0.0.0',
license: 'LGPLv3'
license: 'LGPLv3',
default_options: [ 'cpp_std=c++20' ],
)

project_inc = include_directories('src')
Expand All @@ -20,6 +21,6 @@ fmt_dep = dependency('fmt')
hipo_dep = dependency('hipo4', method: 'cmake', cmake_args: '-DCMAKE_PREFIX_PATH=' + get_option('hipo'))

subdir('src/services')
subdir('src/algorithms/clas12/fiducial_cuts')
subdir('src/algorithms')
subdir('src/iguana')
subdir('src/tests')
70 changes: 70 additions & 0 deletions src/algorithms/clas12/event_builder_filter/EventBuilderFilter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "EventBuilderFilter.h"

namespace iguana::clas12 {

void EventBuilderFilter::Start() {
m_log->Debug("START {}", m_name);

// set configuration
m_log->SetLevel(Logger::Level::trace);
m_opt.mode = EventBuilderFilterOptions::Modes::blank;
m_opt.pids = {11, 211, -211};
}


Algorithm::BankMap EventBuilderFilter::Run(Algorithm::BankMap inBanks) {
m_log->Debug("RUN {}", m_name);

// check the input banks existence
if(MissingInputBanks(inBanks, {"particles"}))
Throw("missing input banks");

// define the output schemata and banks
BankMap outBanks = {
{ "particles", hipo::bank(inBanks.at("particles").getSchema()) }
};

// filter the input bank for requested PDG code(s)
std::set<int> acceptedRows;
for(int row = 0; row < inBanks.at("particles").getRows(); row++) {
auto pid = inBanks.at("particles").get("pid", row);
auto accept = m_opt.pids.contains(pid);
if(accept) acceptedRows.insert(row);
m_log->Debug("input PID {} -- accept = {}", pid, accept);
}

// fill the output bank
switch(m_opt.mode) {

case EventBuilderFilterOptions::Modes::blank:
outBanks.at("particles").setRows(inBanks.at("particles").getRows());
for(int row = 0; row < inBanks.at("particles").getRows(); row++) {
if(acceptedRows.contains(row))
CopyBankRow(inBanks.at("particles"), row, outBanks.at("particles"), row);
else
BlankRow(outBanks.at("particles"), row);
}
break;

case EventBuilderFilterOptions::Modes::compact:
outBanks.at("particles").setRows(acceptedRows.size());
for(int row = 0; auto acceptedRow : acceptedRows)
CopyBankRow(inBanks.at("particles"), acceptedRow, outBanks.at("particles"), row++);
break;

default:
Throw("unknown 'mode' option");

}

// dump the banks and return the output
ShowBanks(inBanks, outBanks);
return outBanks;
}


void EventBuilderFilter::Stop() {
m_log->Debug("STOP {}", m_name);
}

}
30 changes: 30 additions & 0 deletions src/algorithms/clas12/event_builder_filter/EventBuilderFilter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "services/Algorithm.h"

namespace iguana::clas12 {

class EventBuilderFilterOptions {
public:
enum Modes { blank, compact };
Modes mode = blank;
std::set<int> pids = {11, 211};
};


class EventBuilderFilter : public Algorithm {

public:
EventBuilderFilter() : Algorithm("event_builder_filter") {}
~EventBuilderFilter() {}

void Start() override;
Algorithm::BankMap Run(Algorithm::BankMap inBanks) override;
void Stop() override;

private:
EventBuilderFilterOptions m_opt;

};

}
3 changes: 3 additions & 0 deletions src/algorithms/clas12/event_builder_filter/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Event Builder Filter

Filters a particle bank for specific Event Builder PDGs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
algo_headers = [
'FiducialCuts.h',
'clas12/event_builder_filter/EventBuilderFilter.h',
]

algo_sources = [
'FiducialCuts.cc',
'clas12/event_builder_filter/EventBuilderFilter.cc',
]

algo_lib = shared_library(
Expand Down
2 changes: 1 addition & 1 deletion src/iguana/Arbiter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace iguana {

Arbiter::Arbiter() {
algo_map.insert({clas12_FiducialCuts, std::make_shared<clas12::FiducialCuts>()});
algo_map.insert({clas12_EventBuilderFilter, std::make_shared<clas12::EventBuilderFilter>()});
}

}
4 changes: 2 additions & 2 deletions src/iguana/Arbiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <memory>

// TODO: avoid listing the algos
#include "algorithms/clas12/fiducial_cuts/FiducialCuts.h"
#include "algorithms/clas12/event_builder_filter/EventBuilderFilter.h"

namespace iguana {

Expand All @@ -18,7 +18,7 @@ namespace iguana {
// TODO: avoid listing the algos
// TODO: who should own the algorithm instances: Arbiter or the user?
enum algo {
clas12_FiducialCuts
clas12_EventBuilderFilter
};

// TODO: make private
Expand Down
1 change: 1 addition & 0 deletions src/iguana/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ iguana_lib = shared_library(
'Iguana',
iguana_sources,
include_directories: project_inc,
dependencies: [ fmt_dep, hipo_dep ],
link_with: [ algo_lib, services_lib ],
install: true,
install_dir: project_lib_install_dir,
Expand Down
50 changes: 48 additions & 2 deletions src/services/Algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,54 @@

namespace iguana {

Algorithm::Algorithm(std::string name) {
m_log = std::make_shared<Logger>(name);
Algorithm::Algorithm(std::string name) : m_name(name) {
m_log = std::make_shared<Logger>(m_name);
}

bool Algorithm::MissingInputBanks(BankMap banks, std::set<std::string> keys) {
for(auto key : keys) {
if(!banks.contains(key)) {
m_log->Error("Algorithm '{}' is missing the input bank '{}'", m_name, key);
m_log->Error(" => the following input banks are required by '{}':", m_name);
for(auto k : keys)
m_log->Error(" - {}", k);
return true;
}
}
return false;
}

void Algorithm::CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow) {
// TODO: check srcBank.getSchema() == destBank.getSchema()
for(int item = 0; item < srcBank.getSchema().getEntries(); item++) {
auto val = srcBank.get(item, srcRow);
destBank.put(item, destRow, val);
}
}

void Algorithm::BlankRow(hipo::bank bank, int row) {
for(int item = 0; item < bank.getSchema().getEntries(); item++) {
bank.put(item, row, 0);
}
}

void Algorithm::ShowBanks(BankMap banks, std::string message, Logger::Level level) {
if(m_log->GetLevel() <= level) {
m_log->Print(level, message);
for(auto [key,bank] : banks) {
m_log->Print(level, "BANK: '{}'", key);
bank.show();
}
}
}

void Algorithm::ShowBanks(BankMap inBanks, BankMap outBanks, Logger::Level level) {
ShowBanks(inBanks, "===== INPUT BANKS =====", level);
ShowBanks(outBanks, "===== OUTPUT BANKS =====", level);
}

void Algorithm::Throw(std::string message) {
throw std::runtime_error(fmt::format("CRITICAL ERROR: {}; Algorithm '{}' stopped!", message, m_name));
}

}
60 changes: 58 additions & 2 deletions src/services/Algorithm.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,75 @@
#pragma once

#include "Logger.h"
#include <hipo4/bank.h>
#include <set>

namespace iguana {

class Algorithm {

public:

using BankMap = std::unordered_map<std::string, hipo::bank>;

/// Algorithm base class constructor
/// @param name the unique name for a derived class instance
Algorithm(std::string name);

/// Algorithm base class destructor
virtual ~Algorithm() {}

/// Initialize an algorithm before any events are processed
virtual void Start() = 0;
virtual int Run(int a, int b) = 0;

/// Run an algorithm
/// @param inBanks the set of input banks
/// @return a set of output banks
virtual BankMap Run(BankMap inBanks) = 0;

/// Finalize an algorithm after all events are processed
virtual void Stop() = 0;
virtual ~Algorithm() {}

protected:

/// Check if `banks` contains all keys `keys`; this is useful for checking algorithm inputs are complete.
/// @param banks the set of (key,bank) pairs to check
/// @keys the required keys
/// @return true if `banks` is missing any keys in `keys`
bool MissingInputBanks(BankMap banks, std::set<std::string> keys);

/// Copy a row from one bank to another, assuming their schemata are equivalent
/// @param srcBank the source bank
/// @param srcRow the row in `srcBank` to copy from
/// @param destBank the destination bank
/// @param destRow the row in `destBank` to copy to
void CopyBankRow(hipo::bank srcBank, int srcRow, hipo::bank destBank, int destRow);

/// Blank a row, setting all items to zero
/// @param bank the bank to modify
/// @param row the row to blank
void BlankRow(hipo::bank bank, int row);

/// Dump all banks in a BankMap
/// @param banks the banks to show
/// @param message optionally print a header message
/// @param level the log level
void ShowBanks(BankMap banks, std::string message="", Logger::Level level=Logger::trace);

/// Dump all input and output banks
/// @param inBanks the input banks
/// @param outBanks the output banks
/// @param level the log level
void ShowBanks(BankMap inBanks, BankMap outBanks, Logger::Level level=Logger::trace);

/// Stop the algorithm and throw a runtime exception
/// @param message the error message
void Throw(std::string message);

/// algorithm name
std::string m_name;

/// Logger
std::shared_ptr<Logger> m_log;
};
}
4 changes: 4 additions & 0 deletions src/services/Logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ namespace iguana {
Debug("Logger '{}' set to '{}'", m_name, m_level_names.at(m_level));
}

Logger::Level Logger::GetLevel() {
return m_level;
}

}
31 changes: 14 additions & 17 deletions src/services/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,27 @@ namespace iguana {
~Logger() {}

void SetLevel(Level lev);
Level GetLevel();

template <typename... VALUES> void Trace(std::string msg, VALUES... vals) { Print(trace, msg, vals...); }
template <typename... VALUES> void Debug(std::string msg, VALUES... vals) { Print(debug, msg, vals...); }
template <typename... VALUES> void Info(std::string msg, VALUES... vals) { Print(info, msg, vals...); }
template <typename... VALUES> void Warn(std::string msg, VALUES... vals) { Print(warn, msg, vals...); }
template <typename... VALUES> void Error(std::string msg, VALUES... vals) { Print(error, msg, vals...); }
template <typename... VALUES> void Trace(std::string message, VALUES... vals) { Print(trace, message, vals...); }
template <typename... VALUES> void Debug(std::string message, VALUES... vals) { Print(debug, message, vals...); }
template <typename... VALUES> void Info(std::string message, VALUES... vals) { Print(info, message, vals...); }
template <typename... VALUES> void Warn(std::string message, VALUES... vals) { Print(warn, message, vals...); }
template <typename... VALUES> void Error(std::string message, VALUES... vals) { Print(error, message, vals...); }

template <typename... VALUES>
void Print(Level lev, std::string msg, VALUES... vals) {
void Print(Level lev, std::string message, VALUES... vals) {
if(lev >= m_level) {
auto level_name_it = m_level_names.find(lev);
if(level_name_it == m_level_names.end()) {
Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast<int>(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt
Error(msg, vals...);
} else {
if(m_level_names.contains(lev)) {
auto prefix = fmt::format("[{}] [{}] ", m_level_names.at(lev), m_name);
fmt::print(
lev >= warn ? stderr : stdout,
fmt::format(
"[{}] [{}] {}\n",
level_name_it->second,
m_name,
fmt::format(msg, vals...)
)
fmt::runtime(prefix + message + "\n"),
vals...
);
} else {
Warn("Logger::Print called with unknown log level '{}'; printing as error instead", static_cast<int>(lev)); // FIXME: static_cast -> fmt::underlying, but needs new version of fmt
Error(message, vals...);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/services/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ services_lib = shared_library(
'IguanaServices',
services_sources,
include_directories: project_inc,
dependencies: fmt_dep,
dependencies: [ fmt_dep, hipo_dep ],
install: true,
install_dir: project_lib_install_dir,
install_rpath: project_lib_rpath,
)

services_dep = declare_dependency(
dependencies: fmt_dep
dependencies: [ fmt_dep, hipo_dep ]
)

install_headers(services_headers, subdir : meson.project_name())
Loading

0 comments on commit 79505c8

Please sign in to comment.