Skip to content

Commit

Permalink
Fix reading relations to unpersisted objects (AIDASoft#486)
Browse files Browse the repository at this point in the history
* Switch test case from EventStore to Frame based I/O

* Default initialize ObjectIDs to untracked

* Add check for existence before getting collections

* Properly lock the collection ID table

* Make CollectionIDTable return optional for some queries

* Make sure that each category gets its own id table
  • Loading branch information
tmadlener authored and Ananya2003Gupta committed Sep 26, 2023
1 parent 2502565 commit 3644468
Show file tree
Hide file tree
Showing 20 changed files with 83 additions and 76 deletions.
8 changes: 6 additions & 2 deletions include/podio/CollectionIDTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <vector>

Expand All @@ -26,14 +27,17 @@ class CollectionIDTable {
CollectionIDTable(const std::vector<uint32_t>& ids, const std::vector<std::string>& names);

/// return collection ID for given name
uint32_t collectionID(const std::string& name) const;
std::optional<uint32_t> collectionID(const std::string& name) const;

/// return name for given collection ID
const std::string name(uint32_t collectionID) const;
std::optional<const std::string> name(uint32_t collectionID) const;

/// Check if collection name is known
bool present(const std::string& name) const;

/// Check if collection ID is known
bool present(uint32_t collectionID) const;

/// return registered names
const std::vector<std::string>& names() const {
return m_names;
Expand Down
11 changes: 7 additions & 4 deletions include/podio/Frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ podio::CollectionBase* Frame::FrameModel<FrameDataT>::doGet(const std::string& n
}

coll->prepareAfterRead();
coll->setID(m_idTable.collectionID(name));
coll->setID(m_idTable.collectionID(name).value());
{
std::lock_guard mapLock{*m_mapMtx};
auto [it, success] = m_collections.emplace(name, std::move(coll));
Expand All @@ -400,17 +400,20 @@ podio::CollectionBase* Frame::FrameModel<FrameDataT>::doGet(const std::string& n

template <typename FrameDataT>
bool Frame::FrameModel<FrameDataT>::get(uint32_t collectionID, CollectionBase*& collection) const {
const auto& name = m_idTable.name(collectionID);
const auto name = m_idTable.name(collectionID);
if (!name) {
return false;
}
const auto& [_, inserted] = m_retrievedIDs.insert(collectionID);

if (inserted) {
auto coll = doGet(name);
auto coll = doGet(name.value());
if (coll) {
collection = coll;
return true;
}
} else {
auto coll = doGet(name, false);
auto coll = doGet(name.value(), false);
if (coll) {
collection = coll;
return true;
Expand Down
10 changes: 5 additions & 5 deletions include/podio/ObjectID.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ namespace podio {
class ObjectID {

public:
/// index of object in collection
int index;
/// ID of the collection
uint32_t collectionID;

/// not part of a collection
static const int untracked = -1;
/// invalid or non-available object
static const int invalid = -2;

/// index of object in collection
int index{untracked};
/// ID of the collection
uint32_t collectionID{static_cast<uint32_t>(untracked)};

/// index and collectionID uniquely defines the object.
/// this operator is necessary for meaningful comparisons in python
bool operator==(const ObjectID& other) const {
Expand Down
2 changes: 1 addition & 1 deletion include/podio/ROOTNTupleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ROOTNTupleReader {

std::vector<std::string> m_availableCategories{};

std::shared_ptr<podio::CollectionIDTable> m_table{};
std::unordered_map<std::string, std::shared_ptr<podio::CollectionIDTable>> m_idTables{};
};

} // namespace podio
Expand Down
4 changes: 2 additions & 2 deletions python/templates/Collection.cc.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@

{% with collection_type = class.bare_type + 'Collection' %}
{{ collection_type }}::{{ collection_type }}() :
m_isValid(false), m_isPrepared(false), m_isSubsetColl(false), m_collectionID(0), m_storageMtx(std::make_unique<std::mutex>()), m_storage() {}
m_isValid(false), m_isPrepared(false), m_isSubsetColl(false), m_collectionID(podio::ObjectID::untracked), m_storageMtx(std::make_unique<std::mutex>()), m_storage() {}

{{ collection_type }}::{{ collection_type }}({{ collection_type }}Data&& data, bool isSubsetColl) :
m_isValid(false), m_isPrepared(false), m_isSubsetColl(isSubsetColl), m_collectionID(0), m_storageMtx(std::make_unique<std::mutex>()), m_storage(std::move(data)) {}
m_isValid(false), m_isPrepared(false), m_isSubsetColl(isSubsetColl), m_collectionID(podio::ObjectID::untracked), m_storageMtx(std::make_unique<std::mutex>()), m_storage(std::move(data)) {}

{{ collection_type }}::~{{ collection_type }}() {
// Need to tell the storage how to clean-up
Expand Down
4 changes: 2 additions & 2 deletions python/templates/Obj.cc.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
{{ utils.namespace_open(class.namespace) }}
{% with obj_type = class.bare_type + 'Obj' %}
{{ obj_type }}::{{ obj_type }}() :
{% raw %} ObjBase{{podio::ObjectID::untracked, 0}, 0}{% endraw %},
{% raw %} ObjBase{{}, 0}{% endraw %},
data(){{ single_relations_initialize(OneToOneRelations) }}
{%- for relation in OneToManyRelations + VectorMembers %},
m_{{ relation.name }}(new std::vector<{{ relation.full_type }}>())
Expand All @@ -29,7 +29,7 @@
{ }

{{ obj_type }}::{{ obj_type }}(const {{ obj_type }}& other) :
{% raw %} ObjBase{{podio::ObjectID::untracked, 0}, 0}{% endraw %},
{% raw %} ObjBase{{}, 0}{% endraw %},
data(other.data){{ single_relations_initialize(OneToOneRelations) }}
{%- for relation in OneToManyRelations + VectorMembers %},
m_{{ relation.name }}(new std::vector<{{ relation.full_type }}>(*(other.m_{{ relation.name }})))
Expand Down
2 changes: 1 addition & 1 deletion python/templates/macros/implementations.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ const podio::ObjectID {{ full_type }}::getObjectID() const {
if (m_obj) {
return m_obj->id;
}
return podio::ObjectID{podio::ObjectID::invalid, 0};
return podio::ObjectID{};
}

{% set inverse_type = class.bare_type if prefix else 'Mutable' + class.bare_type %}
Expand Down
31 changes: 20 additions & 11 deletions src/CollectionIDTable.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// podio specific includes
#include "podio/CollectionIDTable.h"
#include <algorithm>
#include <iostream>

#include "MurmurHash3.h"

#include <algorithm>
#include <iostream>

namespace podio {

CollectionIDTable::CollectionIDTable() : m_mutex(std::make_unique<std::mutex>()) {
Expand All @@ -18,36 +19,44 @@ CollectionIDTable::CollectionIDTable(const std::vector<uint32_t>& ids, const std
m_collectionIDs(ids), m_names(names), m_mutex(std::make_unique<std::mutex>()) {
}

const std::string CollectionIDTable::name(uint32_t ID) const {
std::lock_guard<std::mutex> lock(*m_mutex);
std::optional<const std::string> CollectionIDTable::name(uint32_t ID) const {
std::lock_guard<std::mutex> lock{*m_mutex};
const auto result = std::find(begin(m_collectionIDs), end(m_collectionIDs), ID);
const auto index = std::distance(m_collectionIDs.begin(), result);
if (index >= static_cast<ptrdiff_t>(m_names.size())) {
return std::nullopt;
}
return m_names[index];
}

uint32_t CollectionIDTable::collectionID(const std::string& name) const {
std::lock_guard<std::mutex> lock(*m_mutex);
std::optional<uint32_t> CollectionIDTable::collectionID(const std::string& name) const {
std::lock_guard<std::mutex> lock{*m_mutex};
const auto result = std::find(begin(m_names), end(m_names), name);
const auto index = std::distance(m_names.begin(), result);
if (index >= static_cast<ptrdiff_t>(m_collectionIDs.size())) {
return std::nullopt;
}
return m_collectionIDs[index];
}

void CollectionIDTable::print() const {
std::lock_guard<std::mutex> lock(*m_mutex);
std::lock_guard<std::mutex> lock{*m_mutex};
std::cout << "CollectionIDTable" << std::endl;
for (unsigned i = 0; i < m_names.size(); ++i) {
std::cout << "\t" << m_names[i] << " : " << m_collectionIDs[i] << std::endl;
}
}

bool CollectionIDTable::present(const std::string& name) const {
std::lock_guard<std::mutex> lock(*m_mutex);
const auto result = std::find(begin(m_names), end(m_names), name);
return result != end(m_names);
return collectionID(name).has_value();
}

bool CollectionIDTable::present(uint32_t collectionID) const {
return name(collectionID).has_value();
}

uint32_t CollectionIDTable::add(const std::string& name) {
std::lock_guard<std::mutex> lock(*m_mutex);
std::lock_guard<std::mutex> lock{*m_mutex};
const auto result = std::find(begin(m_names), end(m_names), name);
uint32_t ID = 0;
if (result == m_names.end()) {
Expand Down
4 changes: 2 additions & 2 deletions src/EventStore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ bool EventStore::get(uint32_t id, CollectionBase*& collection) const {
bool success = false;
if (val.second == true) {
// collection not yet retrieved in recursive-call
auto name = m_table->name(id);
auto name = m_table->name(id).value();
success = doGet(name, collection, true);
} else {
// collection already requested in recursive call
// do not set the references to break collection dependency-cycle
auto name = m_table->name(id);
auto name = m_table->name(id).value();
success = doGet(name, collection, false);
}
// fg: the set should only be cleared at the end of event (in clear() ) ...
Expand Down
4 changes: 2 additions & 2 deletions src/ROOTFrameReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ createCollectionBranchesIndexBased(TChain* chain, const podio::CollectionIDTable
for (const auto& [collID, collType, isSubsetColl, collSchemaVersion] : collInfo) {
// We only write collections that are in the collectionIDTable, so no need
// to check here
const auto name = idTable.name(collID);
const auto name = idTable.name(collID).value();

const auto collectionClass = TClass::GetClass(collType.c_str());
// Need the collection here to setup all the branches. Have to manage the
Expand Down Expand Up @@ -315,7 +315,7 @@ createCollectionBranches(TChain* chain, const podio::CollectionIDTable& idTable,
for (const auto& [collID, collType, isSubsetColl, collSchemaVersion] : collInfo) {
// We only write collections that are in the collectionIDTable, so no need
// to check here
const auto name = idTable.name(collID);
const auto name = idTable.name(collID).value();

root_utils::CollectionBranches branches{};
if (isSubsetColl) {
Expand Down
4 changes: 2 additions & 2 deletions src/ROOTFrameWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ void ROOTFrameWriter::initBranches(CategoryInfo& catInfo, const std::vector<Stor
}

catInfo.branches.push_back(branches);
catInfo.collInfo.emplace_back(catInfo.idTable.collectionID(name), coll->getTypeName(), coll->isSubsetCollection(),
coll->getSchemaVersion());
catInfo.collInfo.emplace_back(catInfo.idTable.collectionID(name).value(), coll->getTypeName(),
coll->isSubsetCollection(), coll->getSchemaVersion());
}

// Also make branches for the parameters
Expand Down
2 changes: 1 addition & 1 deletion src/ROOTLegacyReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void ROOTLegacyReader::createCollectionBranches(const std::vector<root_utils::Co
for (const auto& [collID, collType, isSubsetColl, collSchemaVersion] : collInfo) {
// We only write collections that are in the collectionIDTable, so no need
// to check here
const auto name = m_table->name(collID);
const auto name = m_table->name(collID).value();

root_utils::CollectionBranches branches{};
const auto collectionClass = TClass::GetClass(collType.c_str());
Expand Down
8 changes: 4 additions & 4 deletions src/ROOTNTupleReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ bool ROOTNTupleReader::initCategory(const std::string& category) {
auto schemaVersion = m_metadata_readers[filename]->GetView<std::vector<SchemaVersionT>>("schemaVersion_" + category);
m_collectionInfo[category].schemaVersion = schemaVersion(0);

m_idTables[category] =
std::make_shared<CollectionIDTable>(m_collectionInfo[category].id, m_collectionInfo[category].name);

return true;
}

Expand Down Expand Up @@ -176,11 +179,8 @@ std::unique_ptr<ROOTFrameData> ROOTNTupleReader::readEntry(const std::string& ca
m_readers[category][0]->LoadEntry(entNum);

auto parameters = readEventMetaData(category, entNum);
if (!m_table) {
m_table = std::make_shared<CollectionIDTable>(m_collectionInfo[category].id, m_collectionInfo[category].name);
}

return std::make_unique<ROOTFrameData>(std::move(buffers), m_table, std::move(parameters));
return std::make_unique<ROOTFrameData>(std::move(buffers), m_idTables[category], std::move(parameters));
}

} // namespace podio
4 changes: 2 additions & 2 deletions src/ROOTReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ CollectionBase* ROOTReader::readCollectionData(const root_utils::CollectionBranc
}

// do the unpacking
const auto id = m_table->collectionID(name);
const auto id = m_table->collectionID(name).value();
collection->setID(id);
collection->prepareAfterRead();

Expand Down Expand Up @@ -249,7 +249,7 @@ void ROOTReader::createCollectionBranches(const std::vector<root_utils::Collecti
for (const auto& [collID, collType, isSubsetColl, collSchemaVersion] : collInfo) {
// We only write collections that are in the collectionIDTable, so no need
// to check here
const auto name = m_table->name(collID);
const auto name = m_table->name(collID).value();

root_utils::CollectionBranches branches{};
const auto collectionClass = TClass::GetClass(collType.c_str());
Expand Down
2 changes: 1 addition & 1 deletion src/ROOTWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void ROOTWriter::finish() {
std::vector<root_utils::CollectionInfoT> collectionInfo;
collectionInfo.reserve(m_collectionsToWrite.size());
for (const auto& name : m_collectionsToWrite) {
const auto collID = collIDTable->collectionID(name);
const auto collID = collIDTable->collectionID(name).value();
const podio::CollectionBase* coll{nullptr};
// No check necessary, only registered collections possible
m_store->get(name, coll);
Expand Down
2 changes: 1 addition & 1 deletion src/SIOBlock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ SIOCollectionIDTableBlock::SIOCollectionIDTableBlock(podio::EventStore* store) :
if (!store->get(id, tmp)) {
std::cerr
<< "PODIO-ERROR cannot construct CollectionIDTableBlock because a collection is missing from the store (id: "
<< id << ", name: " << table->name(id) << ")" << std::endl;
<< id << ", name: " << table->name(id).value_or("<not available>") << ")" << std::endl;
}

_types.emplace_back(tmp->getValueTypeName());
Expand Down
2 changes: 1 addition & 1 deletion src/SIOFrameData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ std::vector<std::string> SIOFrameData::getAvailableCollections() {
// no guarantee that it coincides with the index in the blocks.
// Additionally, collection indices start at 1
const auto collID = m_idTable.ids()[i - 1];
collections.push_back(m_idTable.name(collID));
collections.push_back(m_idTable.name(collID).value());
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/SIOReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ CollectionBase* SIOReader::readCollection(const std::string& name) {
std::find_if(begin(m_inputs), end(m_inputs), [&name](const SIOReader::Input& t) { return t.second == name; });

if (p != end(m_inputs)) {
p->first->setID(m_table->collectionID(name));
p->first->setID(m_table->collectionID(name).value());
p->first->prepareAfterRead();
return p->first;
}
Expand Down
2 changes: 1 addition & 1 deletion src/sioUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace sio_utils {

for (const auto& [name, coll] : collections) {
names.emplace_back(name);
ids.emplace_back(collIdTable.collectionID(name));
ids.emplace_back(collIdTable.collectionID(name).value());
types.emplace_back(coll->getValueTypeName());
subsetColl.emplace_back(coll->isSubsetCollection());
}
Expand Down
Loading

0 comments on commit 3644468

Please sign in to comment.