Skip to content

Commit

Permalink
Add python bindings for writing Frames (#447)
Browse files Browse the repository at this point in the history
* Add put method to Frame python bindings

* Add a test that writes a file via the python bindings and ROOT

* Add a c++ test that reads content written in python

* Add a put_parameter method to the python Frame
  • Loading branch information
tmadlener authored Jul 25, 2023
1 parent 25e5e3a commit 65b58ea
Show file tree
Hide file tree
Showing 13 changed files with 411 additions and 14 deletions.
25 changes: 25 additions & 0 deletions python/podio/base_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python3
"""Python module for defining the basic writer interface that is used by the
backend specific bindings"""


class BaseWriterMixin:
"""Mixin class that defines the base interface of the writers.
The backend specific writers inherit from here and have to initialize the
following members:
- _writer: The actual writer that is able to write frames
"""

def write_frame(self, frame, category, collections=None):
"""Write the given frame under the passed category, optionally limiting the
collections that are written.
Args:
frame (podio.frame.Frame): The Frame to write
category (str): The category name
collections (optional, default=None): The subset of collections to
write. If None, all collections are written
"""
# pylint: disable-next=protected-access
self._writer.writeFrame(frame._frame, category, collections or frame.collections)
114 changes: 105 additions & 9 deletions python/podio/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,38 @@ def _determine_cpp_type(idx_and_type):
SUPPORTED_PARAMETER_TYPES = _determine_supported_parameter_types()


def _get_cpp_vector_types(type_str):
"""Get the possible std::vector<cpp_type> from the passed py_type string."""
# Gather a list of all types that match the type_str (c++ or python)
def _get_cpp_types(type_str):
"""Get all possible c++ types from the passed py_type string."""
types = list(filter(lambda t: type_str in t, SUPPORTED_PARAMETER_TYPES))
if not types:
raise ValueError(f'{type_str} cannot be mapped to a valid parameter type')

return types


def _get_cpp_vector_types(type_str):
"""Get the possible std::vector<cpp_type> from the passed py_type string."""
# Gather a list of all types that match the type_str (c++ or python)
types = _get_cpp_types(type_str)
return [f'std::vector<{t}>' for t in map(lambda x: x[0], types)]


def _is_collection_base(thing):
"""Check whether the passed thing is a podio::CollectionBase
Args:
thing (any): any object
Returns:
bool: True if thing is a base of podio::CollectionBase, False otherwise
"""
# Make sure to only instantiate the template with things that cppyy
# understands
if "cppyy" in repr(thing):
return cppyy.gbl.std.is_base_of[cppyy.gbl.podio.CollectionBase, type(thing)].value
return False


class Frame:
"""Frame class that serves as a container of collection and meta data."""

Expand All @@ -78,17 +100,16 @@ def __init__(self, data=None):
else:
self._frame = podio.Frame()

self._collections = tuple(str(s) for s in self._frame.getAvailableCollections())
self._param_key_types = self._init_param_keys()
self._param_key_types = self._get_param_keys_types()

@property
def collections(self):
"""Get the available collection (names) from this Frame.
"""Get the currently available collection (names) from this Frame.
Returns:
tuple(str): The names of the available collections from this Frame.
"""
return self._collections
return tuple(str(s) for s in self._frame.getAvailableCollections())

def get(self, name):
"""Get a collection from the Frame by name.
Expand All @@ -107,9 +128,32 @@ def get(self, name):
raise KeyError(f"Collection '{name}' is not available")
return collection

def put(self, collection, name):
"""Put the collection into the frame
The passed collectoin is "moved" into the Frame, i.e. it cannot be used any
longer after a call to this function. This also means that only objects that
were in the collection at the time of calling this function will be
available afterwards.
Args:
collection (podio.CollectionBase): The collection to put into the Frame
name (str): The name of the collection
Returns:
podio.CollectionBase: The reference to the collection that has been put
into the Frame. NOTE: That mutating this collection is not allowed.
Raises:
ValueError: If collection is not actually a podio.CollectionBase
"""
if not _is_collection_base(collection):
raise ValueError("Can only put podio collections into a Frame")
return self._frame.put(cppyy.gbl.std.move(collection), name)

@property
def parameters(self):
"""Get the available parameter names from this Frame.
"""Get the currently available parameter names from this Frame.
Returns:
tuple (str): The names of the available parameters from this Frame.
Expand Down Expand Up @@ -163,6 +207,58 @@ def _get_param_value(par_type, name):

return _get_param_value(vec_types[0], name)

def put_parameter(self, key, value, as_type=None):
"""Put a parameter into the Frame.
Puts a parameter into the Frame after doing some (incomplete) type checks.
If a list is passed the parameter type is determined from looking at the
first element of the list only. Additionally, since python doesn't
differentiate between floats and doubles, floats will always be stored as
doubles by default, use the as_type argument to change this if necessary.
Args:
key (str): The name of the parameter
value (int, float, str or list of these): The parameter value
as_type (str, optional): Explicitly specify the type that should be used
to put the parameter into the Frame. Python types (e.g. "str") will
be converted to c++ types. This will override any automatic type
deduction that happens otherwise. Note that this will be taken at
pretty much face-value and there are only limited checks for this.
Raises:
ValueError: If a non-supported parameter type is passed
"""
# For lists we determine the c++ vector type and use that to call the
# correct template overload explicitly
if isinstance(value, (list, tuple)):
type_name = as_type or type(value[0]).__name__
vec_types = _get_cpp_vector_types(type_name)
if len(vec_types) == 0:
raise ValueError(f"Cannot put a parameter of type {type_name} into a Frame")

par_type = vec_types[0]
if isinstance(value[0], float):
# Always store floats as doubles from the python side
par_type = par_type.replace("float", "double")

self._frame.putParameter[par_type](key, value)
else:
if as_type is not None:
cpp_types = _get_cpp_types(as_type)
if len(cpp_types) == 0:
raise ValueError(f"Cannot put a parameter of type {as_type} into a Frame")
self._frame.putParameter[cpp_types[0]](key, value)

# If we have a single integer, a std::string overload kicks in with higher
# priority than the template for some reason. So we explicitly select the
# correct template here
elif isinstance(value, int):
self._frame.putParameter["int"](key, value)
else:
self._frame.putParameter(key, value)

self._param_key_types = self._get_param_keys_types() # refresh the cache

def get_parameters(self):
"""Get the complete podio::GenericParameters object stored in this Frame.
Expand Down Expand Up @@ -200,7 +296,7 @@ def get_param_info(self, name):

return par_infos

def _init_param_keys(self):
def _get_param_keys_types(self):
"""Initialize the param keys dict for easier lookup of the available parameters.
Returns:
Expand Down
14 changes: 12 additions & 2 deletions python/podio/root_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from ROOT import podio # noqa: E402 # pylint: disable=wrong-import-position

from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position

Writer = podio.ROOTFrameWriter
from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position


class Reader(BaseReaderMixin):
Expand Down Expand Up @@ -49,3 +48,14 @@ def __init__(self, filenames):
self._is_legacy = True

super().__init__()


class Writer(BaseWriterMixin):
"""Writer class for writing podio root files"""
def __init__(self, filename):
"""Create a writer for writing files
Args:
filename (str): The name of the output file
"""
self._writer = podio.ROOTFrameWriter(filename)
14 changes: 12 additions & 2 deletions python/podio/sio_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from ROOT import podio # noqa: 402 # pylint: disable=wrong-import-position

from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position

Writer = podio.SIOFrameWriter
from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position


class Reader(BaseReaderMixin):
Expand Down Expand Up @@ -46,3 +45,14 @@ def __init__(self, filename):
self._is_legacy = True

super().__init__()


class Writer(BaseWriterMixin):
"""Writer class for writing podio root files"""
def __init__(self, filename):
"""Create a writer for writing files
Args:
filename (str): The name of the output file
"""
self._writer = podio.SIOFrameWriter(filename)
59 changes: 59 additions & 0 deletions python/podio/test_Frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# using root_io as that should always be present regardless of which backends are built
from podio.root_io import Reader

from podio.test_utils import ExampleHitCollection

# The expected collections in each frame
EXPECTED_COLL_NAMES = {
'arrays', 'WithVectorMember', 'info', 'fixedWidthInts', 'mcparticles',
Expand Down Expand Up @@ -34,6 +36,63 @@ def test_frame_invalid_access(self):
with self.assertRaises(KeyError):
_ = frame.get_parameter('NonExistantParameter')

with self.assertRaises(ValueError):
collection = [1, 2, 4]
_ = frame.put(collection, "invalid_collection_type")

def test_frame_put_collection(self):
"""Check that putting a collection works as expected"""
frame = Frame()
self.assertEqual(frame.collections, tuple())

hits = ExampleHitCollection()
hits.create()
hits2 = frame.put(hits, "hits_from_python")
self.assertEqual(frame.collections, tuple(["hits_from_python"]))
# The original collection is gone at this point, and ideally just leaves an
# empty shell
self.assertEqual(len(hits), 0)
# On the other hand the return value of put has the original content
self.assertEqual(len(hits2), 1)

def test_frame_put_parameters(self):
"""Check that putting a parameter works as expected"""
frame = Frame()
self.assertEqual(frame.parameters, tuple())

frame.put_parameter("a_string_param", "a string")
self.assertEqual(frame.parameters, tuple(["a_string_param"]))
self.assertEqual(frame.get_parameter("a_string_param"), "a string")

frame.put_parameter("float_param", 3.14)
self.assertEqual(frame.get_parameter("float_param"), 3.14)

frame.put_parameter("int", 42)
self.assertEqual(frame.get_parameter("int"), 42)

frame.put_parameter("string_vec", ["a", "b", "cd"])
str_vec = frame.get_parameter("string_vec")
self.assertEqual(len(str_vec), 3)
self.assertEqual(str_vec, ["a", "b", "cd"])

frame.put_parameter("more_ints", [1, 2345])
int_vec = frame.get_parameter("more_ints")
self.assertEqual(len(int_vec), 2)
self.assertEqual(int_vec, [1, 2345])

frame.put_parameter("float_vec", [1.23, 4.56, 7.89])
vec = frame.get_parameter("float_vec", as_type="double")
self.assertEqual(len(vec), 3)
self.assertEqual(vec, [1.23, 4.56, 7.89])

frame.put_parameter("real_float_vec", [1.23, 4.56, 7.89], as_type="float")
f_vec = frame.get_parameter("real_float_vec", as_type="float")
self.assertEqual(len(f_vec), 3)
self.assertEqual(vec, [1.23, 4.56, 7.89])

frame.put_parameter("float_as_float", 3.14, as_type="float")
self.assertAlmostEqual(frame.get_parameter("float_as_float"), 3.14, places=5)


class FrameReadTest(unittest.TestCase):
"""Unit tests for the Frame python bindings for Frames read from file.
Expand Down
54 changes: 53 additions & 1 deletion python/podio/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,57 @@
"""Utilities for python unittests"""

import os
import ROOT
ROOT.gSystem.Load("libTestDataModelDict.so") # noqa: E402
from ROOT import ExampleHitCollection, ExampleClusterCollection # noqa: E402 # pylint: disable=wrong-import-position

SKIP_SIO_TESTS = os.environ.get('SKIP_SIO_TESTS', '1') == '1'
from podio.frame import Frame # pylint: disable=wrong-import-position


SKIP_SIO_TESTS = os.environ.get("SKIP_SIO_TESTS", "1") == "1"


def create_hit_collection():
"""Create a simple hit collection with two hits for testing"""
hits = ExampleHitCollection()
hits.create(0xBAD, 0.0, 0.0, 0.0, 23.0)
hits.create(0xCAFFEE, 1.0, 0.0, 0.0, 12.0)

return hits


def create_cluster_collection():
"""Create a simple cluster collection with two clusters"""
clusters = ExampleClusterCollection()
clu0 = clusters.create()
clu0.energy(3.14)
clu1 = clusters.create()
clu1.energy(1.23)

return clusters


def create_frame():
"""Create a frame with an ExampleHit and an ExampleCluster collection"""
frame = Frame()
hits = create_hit_collection()
frame.put(hits, "hits_from_python")
clusters = create_cluster_collection()
frame.put(clusters, "clusters_from_python")

frame.put_parameter("an_int", 42)
frame.put_parameter("some_floats", [1.23, 7.89, 3.14])
frame.put_parameter("greetings", ["from", "python"])
frame.put_parameter("real_float", 3.14, as_type="float")
frame.put_parameter("more_real_floats", [1.23, 4.56, 7.89], as_type="float")

return frame


def write_file(writer_type, filename):
"""Write a file using the given Writer type and put one Frame into it under
the events category
"""
writer = writer_type(filename)
event = create_frame()
writer.write_frame(event, "events")
4 changes: 4 additions & 0 deletions tests/CTestCustom.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ if ((NOT "@FORCE_RUN_ALL_TESTS@" STREQUAL "ON") AND (NOT "@USE_SANITIZER@" STREQ
read-legacy-files-root_v00-13
read_frame_legacy_root
read_frame_root_multiple
write_python_frame_root
read_python_frame_root

write_frame_root
read_frame_root
Expand All @@ -35,6 +37,8 @@ if ((NOT "@FORCE_RUN_ALL_TESTS@" STREQUAL "ON") AND (NOT "@USE_SANITIZER@" STREQ
write_frame_sio
read_frame_sio
read_frame_legacy_sio
write_python_frame_sio
read_python_frame_sio

write_ascii

Expand Down
Loading

0 comments on commit 65b58ea

Please sign in to comment.