Skip to content

Commit

Permalink
Use std::unique_ptr<podio::CollectionBase> to save collections in t…
Browse files Browse the repository at this point in the history
…he store and fix leak in the Writer (#250)
  • Loading branch information
jmcarcell authored Oct 28, 2024
1 parent 33b97da commit 692b979
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 142 deletions.
32 changes: 13 additions & 19 deletions k4FWCore/components/CollectionMerger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@
#include "k4FWCore/Transformer.h"

#include <map>
#include <memory>
#include <string>
#include <string_view>

struct CollectionMerger final : k4FWCore::Transformer<std::shared_ptr<podio::CollectionBase>(
const std::vector<const std::shared_ptr<podio::CollectionBase>*>&)> {
struct CollectionMerger final
: k4FWCore::Transformer<podio::CollectionBase*(const std::vector<const podio::CollectionBase*>&)> {
CollectionMerger(const std::string& name, ISvcLocator* svcLoc)
: Transformer(name, svcLoc, {KeyValues("InputCollections", {"MCParticles"})},
{KeyValues("OutputCollection", {"NewMCParticles"})}) {
Expand Down Expand Up @@ -91,43 +90,38 @@ struct CollectionMerger final : k4FWCore::Transformer<std::shared_ptr<podio::Col
&CollectionMerger::mergeCollections<edm4hep::GeneratorPdfInfoCollection>;
}

std::shared_ptr<podio::CollectionBase> operator()(
const std::vector<const std::shared_ptr<podio::CollectionBase>*>& input) const override {
std::shared_ptr<podio::CollectionBase> ret;
podio::CollectionBase* operator()(const std::vector<const podio::CollectionBase*>& input) const override {
podio::CollectionBase* ret = nullptr;
debug() << "Merging " << input.size() << " collections" << endmsg;
std::string_view type = "";
for (const auto& coll : input) {
debug() << "Merging collection of type " << (*coll)->getTypeName() << " with " << (*coll)->size() << " elements"
debug() << "Merging collection of type " << coll->getTypeName() << " with " << coll->size() << " elements"
<< endmsg;
if (type.empty()) {
type = (*coll)->getTypeName();
} else if (type != (*coll)->getTypeName()) {
type = coll->getTypeName();
} else if (type != coll->getTypeName()) {
throw std::runtime_error("Different collection types are not supported");
return ret;
}
(this->*m_map.at((*coll)->getTypeName()))(*coll, ret);
(this->*m_map.at(coll->getTypeName()))(coll, ret);
}
return ret;
}

private:
using MergeType = void (CollectionMerger::*)(const std::shared_ptr<podio::CollectionBase>&,
std::shared_ptr<podio::CollectionBase>&) const;
using MergeType = void (CollectionMerger::*)(const podio::CollectionBase*, podio::CollectionBase*&) const;
std::map<std::string_view, MergeType> m_map;
Gaudi::Property<bool> m_copy{this, "Copy", false,
"Copy the elements of the collections instead of creating a subset collection"};

template <typename T>
void mergeCollections(const std::shared_ptr<podio::CollectionBase>& source,
std::shared_ptr<podio::CollectionBase>& ret) const {
template <typename T> void mergeCollections(const podio::CollectionBase* source, podio::CollectionBase*& ret) const {
if (!ret) {
ret = std::make_shared<T>();
ret = new T();
if (!m_copy) {
ret->setSubsetCollection();
}
}
const auto ptr = std::static_pointer_cast<T>(ret);
const auto sourceColl = std::static_pointer_cast<T>(source);
const auto ptr = static_cast<T*>(ret);
const auto sourceColl = static_cast<const T*>(source);
if (m_copy) {
for (const auto& elem : *sourceColl) {
ptr->push_back(elem.clone());
Expand Down
5 changes: 2 additions & 3 deletions k4FWCore/components/IIOSvc.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ class IIOSvc : virtual public IInterface {
* @brief Read the next event from the input file
* @return A tuple containing the collections read, the collection names and the frame that owns the collections
*/
virtual std::tuple<std::vector<std::shared_ptr<podio::CollectionBase>>, std::vector<std::string>, podio::Frame>
next() = 0;
virtual std::shared_ptr<std::vector<std::string>> getCollectionNames() const = 0;
virtual std::tuple<std::vector<podio::CollectionBase*>, std::vector<std::string>, podio::Frame> next() = 0;
virtual std::shared_ptr<std::vector<std::string>> getCollectionNames() const = 0;

virtual podio::Writer& getWriter() = 0;
virtual void deleteWriter() = 0;
Expand Down
16 changes: 9 additions & 7 deletions k4FWCore/components/IOSvc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,14 @@ StatusCode IOSvc::initialize() {

StatusCode IOSvc::finalize() { return Service::finalize(); }

std::tuple<std::vector<std::shared_ptr<podio::CollectionBase>>, std::vector<std::string>, podio::Frame> IOSvc::next() {
std::tuple<std::vector<podio::CollectionBase*>, std::vector<std::string>, podio::Frame> IOSvc::next() {
podio::Frame frame;
{
std::scoped_lock<std::mutex> lock(m_changeBufferLock);
std::lock_guard<std::mutex> lock(m_changeBufferLock);
if (m_nextEntry < m_entries) {
frame = podio::Frame(m_reader->readEvent(m_nextEntry));
} else {
return std::make_tuple(std::vector<std::shared_ptr<podio::CollectionBase>>(), std::vector<std::string>(),
std::move(frame));
return std::make_tuple(std::vector<podio::CollectionBase*>(), std::vector<std::string>(), std::move(frame));
}
m_nextEntry++;
if (m_collectionNames.empty()) {
Expand All @@ -134,11 +133,11 @@ std::tuple<std::vector<std::shared_ptr<podio::CollectionBase>>, std::vector<std:
}
}

std::vector<std::shared_ptr<podio::CollectionBase>> collections;
std::vector<podio::CollectionBase*> collections;

for (const auto& name : m_collectionNames) {
auto ptr = const_cast<podio::CollectionBase*>(frame.get(name));
collections.push_back(std::shared_ptr<podio::CollectionBase>(ptr));
collections.push_back(ptr);
}

return std::make_tuple(collections, m_collectionNames, std::move(frame));
Expand Down Expand Up @@ -178,7 +177,10 @@ void IOSvc::handle(const Incident& incident) {
code = m_dataSvc->retrieveObject("/Event/" + coll, collPtr);
if (code.isSuccess()) {
debug() << "Removing the collection: " << coll << " from the store" << endmsg;
code = m_dataSvc->unregisterObject(collPtr);
code = m_dataSvc->unregisterObject(collPtr);
auto storePtr = dynamic_cast<AnyDataWrapper<std::unique_ptr<podio::CollectionBase>>*>(collPtr);
storePtr->getData().release();
delete storePtr;
} else {
error() << "Expected collection " << coll << " in the store but it was not found" << endmsg;
}
Expand Down
3 changes: 1 addition & 2 deletions k4FWCore/components/IOSvc.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ class IOSvc : public extends<Service, IIOSvc, IIncidentListener> {
StatusCode initialize() override;
StatusCode finalize() override;

std::tuple<std::vector<std::shared_ptr<podio::CollectionBase>>, std::vector<std::string>, podio::Frame> next()
override;
std::tuple<std::vector<podio::CollectionBase*>, std::vector<std::string>, podio::Frame> next() override;

std::shared_ptr<std::vector<std::string>> getCollectionNames() const override {
return std::make_shared<std::vector<std::string>>(m_collectionNames);
Expand Down
21 changes: 9 additions & 12 deletions k4FWCore/components/Reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <GaudiKernel/SmartIF.h>
#include "Gaudi/Functional/details.h"
#include "Gaudi/Functional/utilities.h"
#include "GaudiKernel/AnyDataWrapper.h"
#include "GaudiKernel/FunctionalFilterDecision.h"
#include "GaudiKernel/IDataProviderSvc.h"
#include "GaudiKernel/SmartIF.h"
#include "GaudiKernel/StatusCode.h"

#include "podio/CollectionBase.h"
Expand All @@ -32,18 +32,16 @@

#include <memory>

template <typename Container> using vector_of_ = std::vector<Container>;

class CollectionPusher : public Gaudi::Functional::details::BaseClass_t<Gaudi::Functional::Traits::useDefaults> {
using Traits_ = Gaudi::Functional::Traits::useDefaults;
using Out = std::shared_ptr<podio::CollectionBase>;
using Out = std::unique_ptr<podio::CollectionBase>;
using base_class = Gaudi::Functional::details::BaseClass_t<Traits_>;
static_assert(std::is_base_of_v<Algorithm, base_class>, "BaseClass must inherit from Algorithm");

template <typename T>
using OutputHandle_t = Gaudi::Functional::details::OutputHandle_t<Traits_, std::remove_pointer_t<T>>;
std::vector<OutputHandle_t<std::shared_ptr<podio::CollectionBase>>> m_outputs;
Gaudi::Property<std::vector<std::string>> m_inputCollections{
std::vector<OutputHandle_t<Out>> m_outputs;
Gaudi::Property<std::vector<std::string>> m_inputCollections{
this, "InputCollections", {"First collection"}, "List of input collections"};
// Gaudi::Property<std::string> m_input{this, "Input", "Event", "Input file"};

Expand Down Expand Up @@ -73,7 +71,7 @@ class CollectionPusher : public Gaudi::Functional::details::BaseClass_t<Gaudi::F
try {
auto out = (*this)();

auto outColls = std::get<std::vector<std::shared_ptr<podio::CollectionBase>>>(out);
auto outColls = std::get<std::vector<podio::CollectionBase*>>(out);
auto outputLocations = std::get<std::vector<std::string>>(out);

// if (out.size() != m_outputs.size()) {
Expand All @@ -82,7 +80,7 @@ class CollectionPusher : public Gaudi::Functional::details::BaseClass_t<Gaudi::F
// this->name(), StatusCode::FAILURE);
// }
for (size_t i = 0; i != outColls.size(); ++i) {
m_outputs[i].put(std::move(outColls[i]));
m_outputs[i].put(std::unique_ptr<podio::CollectionBase>(outColls[i]));
}
return Gaudi::Functional::FilterDecision::PASSED;
} catch (GaudiException& e) {
Expand All @@ -91,7 +89,7 @@ class CollectionPusher : public Gaudi::Functional::details::BaseClass_t<Gaudi::F
}
}

virtual std::tuple<vector_of_<Out>, std::vector<std::string>> operator()() const = 0;
virtual std::tuple<std::vector<podio::CollectionBase*>, std::vector<std::string>> operator()() const = 0;

private:
ServiceHandle<IDataProviderSvc> m_dataSvc{this, "EventDataSvc", "EventDataSvc"};
Expand All @@ -108,7 +106,7 @@ class Reader final : public CollectionPusher {
// Gaudi doesn't run the destructor of the Services so we have to
// manually ask for the reader to be deleted so it will call finish()
// See https://gitlab.cern.ch/gaudi/Gaudi/-/issues/169
~Reader() { iosvc->deleteReader(); }
~Reader() override { iosvc->deleteReader(); }

ServiceHandle<IIOSvc> iosvc{this, "IOSvc", "IOSvc"};

Expand All @@ -131,8 +129,7 @@ class Reader final : public CollectionPusher {
// The IOSvc takes care of reading and passing the data
// By convention the Frame is pushed to the store
// so that it's deleted at the right time
std::tuple<std::vector<std::shared_ptr<podio::CollectionBase>>, std::vector<std::string>> operator()()
const override {
std::tuple<std::vector<podio::CollectionBase*>, std::vector<std::string>> operator()() const override {
auto val = iosvc->next();

auto eds = eventSvc().as<IDataProviderSvc>();
Expand Down
34 changes: 20 additions & 14 deletions k4FWCore/components/Writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "GaudiKernel/IDataProviderSvc.h"
#include "GaudiKernel/SmartDataPtr.h"
#include "GaudiKernel/StatusCode.h"

#include "podio/Frame.h"

#include "IIOSvc.h"
Expand Down Expand Up @@ -185,18 +184,22 @@ class Writer final : public Gaudi::Functional::Consumer<void(const EventContext&
}
}

DataObject* p;
StatusCode code = m_dataSvc->retrieveObject("/Event" + k4FWCore::frameLocation, p);
AnyDataWrapper<podio::Frame>* ptr;
DataObject* p;
StatusCode code = m_dataSvc->retrieveObject("/Event" + k4FWCore::frameLocation, p);
std::unique_ptr<AnyDataWrapper<podio::Frame>> ptr;
// This is the case when we are reading from a file
// Putting it into a unique_ptr will make sure it's deleted
if (code.isSuccess()) {
m_dataSvc->unregisterObject(p).ignore();
ptr = dynamic_cast<AnyDataWrapper<podio::Frame>*>(p);
auto sc = m_dataSvc->unregisterObject(p);
if (!sc.isSuccess()) {
error() << "Failed to unregister object" << endmsg;
return;
}
ptr = std::unique_ptr<AnyDataWrapper<podio::Frame>>(dynamic_cast<AnyDataWrapper<podio::Frame>*>(p));
}
// This is the case when no reading is being done
// The new wrapper is owned by the unique_ptr, not by the store
else {
ptr = new AnyDataWrapper<podio::Frame>(podio::Frame());
ptr = std::make_unique<AnyDataWrapper<podio::Frame>>(podio::Frame());
}

const auto& frameCollections = ptr->getData().getAvailableCollections();
Expand Down Expand Up @@ -229,6 +232,10 @@ class Writer final : public Gaudi::Functional::Consumer<void(const EventContext&
error() << "Failed to unregister collection " << coll << endmsg;
return;
}
// We still have to delete the AnyDataWrapper to avoid a leak
auto storePtr = dynamic_cast<AnyDataWrapper<std::unique_ptr<podio::CollectionBase>>*>(storeCollection);
storePtr->getData().release();
delete storePtr;
}

for (auto& coll : m_collectionsToAdd) {
Expand All @@ -242,8 +249,11 @@ class Writer final : public Gaudi::Functional::Consumer<void(const EventContext&
error() << "Failed to unregister collection " << coll << endmsg;
return;
}
const auto collection = dynamic_cast<AnyDataWrapper<std::shared_ptr<podio::CollectionBase>>*>(storeCollection);
if (!collection) {
const auto collection = dynamic_cast<AnyDataWrapper<std::unique_ptr<podio::CollectionBase>>*>(storeCollection);
if (collection) {
ptr->getData().put(std::move(collection->getData()), coll);
delete collection;
} else {
// Check the case when the data has been produced using the old DataHandle
const auto old_collection = dynamic_cast<DataWrapperBase*>(storeCollection);
if (!old_collection) {
Expand All @@ -254,10 +264,6 @@ class Writer final : public Gaudi::Functional::Consumer<void(const EventContext&
const_cast<podio::CollectionBase*>(old_collection->collectionBase()));
ptr->getData().put(std::move(uptr), coll);
}

} else {
std::unique_ptr<podio::CollectionBase> uptr(collection->getData().get());
ptr->getData().put(std::move(uptr), coll);
}
}

Expand Down
13 changes: 6 additions & 7 deletions k4FWCore/include/k4FWCore/Consumer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
#ifndef FWCORE_CONSUMER_H
#define FWCORE_CONSUMER_H

#include <GaudiKernel/FunctionalFilterDecision.h>
#include "Gaudi/Functional/details.h"
#include "Gaudi/Functional/utilities.h"
#include "GaudiKernel/FunctionalFilterDecision.h"

// #include "GaudiKernel/CommonMessaging.h"

Expand All @@ -46,11 +46,10 @@ namespace k4FWCore {
static_assert(((std::is_base_of_v<podio::CollectionBase, In> || isVectorLike_v<In>)&&...),
"Consumer input types must be EDM4hep collections or vectors of collection pointers");

template <typename T>
using InputHandle_t = Gaudi::Functional::details::InputHandle_t<Traits_, std::remove_pointer_t<T>>;
template <typename T> using InputHandle_t = Gaudi::Functional::details::InputHandle_t<Traits_, T>;

std::tuple<std::vector<InputHandle_t<typename transformType<In>::type>>...> m_inputs;
std::array<Gaudi::Property<std::vector<DataObjID>>, sizeof...(In)> m_inputLocations{};
std::tuple<std::vector<InputHandle_t<typename EventStoreType<In>::type>>...> m_inputs;
std::array<Gaudi::Property<std::vector<DataObjID>>, sizeof...(In)> m_inputLocations{};

using base_class = Gaudi::Functional::details::DataHandleMixin<std::tuple<>, std::tuple<>, Traits_>;

Expand All @@ -66,9 +65,9 @@ namespace k4FWCore {
m_inputLocations{Gaudi::Property<std::vector<DataObjID>>{
this, std::get<I>(inputs).first, to_DataObjID(std::get<I>(inputs).second),
[this](Gaudi::Details::PropertyBase&) {
std::vector<InputHandle_t<typename transformType<In>::type>> handles;
std::vector<InputHandle_t<EventStoreType_t>> handles;
for (auto& value : this->m_inputLocations[I].value()) {
auto handle = InputHandle_t<typename transformType<In>::type>(value, this);
auto handle = InputHandle_t<EventStoreType_t>(value, this);
handles.push_back(std::move(handle));
}
std::get<I>(m_inputs) = std::move(handles);
Expand Down
6 changes: 2 additions & 4 deletions k4FWCore/include/k4FWCore/DataHandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@

#include "edm4hep/Constants.h"

#include "TTree.h"

#include <GaudiKernel/AnyDataWrapper.h>
#include <type_traits>

Expand Down Expand Up @@ -133,9 +131,9 @@ template <typename T> const T* DataHandle<T>::get() {
return reinterpret_cast<const T*>(tmp->collectionBase());
} else {
// When a functional has pushed a std::unique_ptr<podio::CollectionBase> into the store
auto ptr = static_cast<AnyDataWrapper<std::shared_ptr<podio::CollectionBase>>*>(dataObjectp)->getData();
auto ptr = static_cast<AnyDataWrapper<std::unique_ptr<podio::CollectionBase>>*>(dataObjectp);
if (ptr) {
return reinterpret_cast<const T*>(ptr.get());
return static_cast<const T*>(ptr->getData().get());
}
std::string errorMsg("The type provided for " + DataObjectHandle<DataWrapper<T>>::pythonRepr() +
" is different from the one of the object in the store.");
Expand Down
4 changes: 2 additions & 2 deletions k4FWCore/include/k4FWCore/DataWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ template <class T> class GAUDI_API DataWrapper : public DataWrapperBase {
template <class T2> friend class DataHandle;

public:
DataWrapper() : m_data(nullptr){};
DataWrapper() : m_data(nullptr) {}
DataWrapper(T&& coll) {
m_data = new T(std::move(coll));
is_owner = true;
}
DataWrapper(std::unique_ptr<T> uptr) : m_data(uptr.get()) {
uptr.release();
is_owner = false;
};
}
virtual ~DataWrapper();

const T* getData() const { return m_data; }
Expand Down
Loading

0 comments on commit 692b979

Please sign in to comment.