Skip to content

Commit

Permalink
Merge pull request #67 from zwicker-group/model_storage
Browse files Browse the repository at this point in the history
Change location of storage of models
  • Loading branch information
david-zwicker authored Apr 13, 2024
2 parents f37b227 + 9fef62d commit 4d6406e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
22 changes: 14 additions & 8 deletions modelrunner/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ def __init__(
:meth:`~Parameterized.get_parameters` or displayed by calling
:meth:`~Parameterized.show_parameters`.
output (str):
Path where the output file will be written.
Path where the output file will be written. The output will be written
using :mod:`~modelrunner.storage` and might contain two groups: `result`
to which the final result of the model is written, and `data`, which
can contain extra information that is written using
:attr:`~ModelBase.storage`.
mode (str or :class:`~modelrunner.storage.access_modes.ModeType`):
The file mode with which the storage is accessed, which determines the
allowed operations. Common options are "read", "full", "append", and
Expand All @@ -87,7 +91,9 @@ def storage(self) -> StorageGroup:
if self.output is None:
raise RuntimeError("Output file needs to be specified")
self._storage = open_storage(self.output, mode=self.mode)
return self._storage
return self._storage.create_group("data")
else:
return self._storage.open_group("data")

def close(self) -> None:
"""close any opened storages"""
Expand Down Expand Up @@ -124,21 +130,21 @@ def write_result(self, result: Result | None = None) -> None:
result:
The result data. If omitted, the model is run to obtain results
"""
from .results import Result # @Reimport

if self.output is None:
raise RuntimeError("Output file needs to be specified")

if result is None:
result = self.get_result()
else:
from .results import Result # @Reimport

assert isinstance(result, Result)
elif not isinstance(result, Result):
raise TypeError(f"result has type {result.__class__} instead of `Result`")

if self._storage is not None:
# reuse the opened storage
result.to_file(self.storage)
result.to_file(self._storage, loc="result")
else:
result.to_file(self.output, mode=self.mode)
result.to_file(self.output, loc="result", mode=self.mode)

@classmethod
def _prepare_argparser(cls, name: str | None = None) -> argparse.ArgumentParser:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def __call__(self, b):
a.close()

with open_storage(tmp_path / "model.json") as storage:
assert storage["info"] == {"args": 5}
assert storage["data/info"] == {"args": 5}


@pytest.mark.parametrize("kwarg", [True, False])
Expand All @@ -351,5 +351,5 @@ def model_with_output(storage, a=3):
m.write_result()

with open_storage(path) as storage:
assert storage["saved"] == {"A": "B"}
assert storage["data/saved"] == {"A": "B"}
assert storage["result"].data == 5

0 comments on commit 4d6406e

Please sign in to comment.