Skip to content

Commit

Permalink
Merge pull request #68 from zwicker-group/model_run
Browse files Browse the repository at this point in the history
Deprecated `.data` attribute of `Result` class
  • Loading branch information
david-zwicker authored Apr 13, 2024
2 parents 4d6406e + c52232e commit cae2932
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 49 deletions.
2 changes: 1 addition & 1 deletion examples/io_hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ def number_range(start: float = 1, length: int = 3):

# write result from file
read = Result.from_file("test.hdf")
print(read.parameters, "–– start + [0..length-1] =", read.data)
print(read.parameters, "–– start + [0..length-1] =", read.result)
2 changes: 1 addition & 1 deletion examples/io_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def multiply(a: float = 1, b: float = 2):
result.to_file("test.json")

read = Result.from_file("test.json")
print(read.parameters, "–– a * b =", read.data)
print(read.parameters, "–– a * b =", read.result)
2 changes: 1 addition & 1 deletion examples/io_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def multiply(a: float = 1, b: float = 2):
result.to_file("test.yaml")

read = Result.from_file("test.yaml")
print(read.parameters, "–– a * b =", read.data)
print(read.parameters, "–– a * b =", read.result)
2 changes: 1 addition & 1 deletion examples/model_storage_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def multiply(a, b=2, storage=None):
# read the file and check whether all the data is there
with open_storage(fp.name) as storage:
print("Stored data:", storage["data"])
print("Model result:", storage["result"].data)
print("Model result:", storage["result"].result)
2 changes: 1 addition & 1 deletion examples/script_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ def multiply(a: float = 1, b: float = 2):

if __name__ == "__main__":
result = run_function_with_cmd_args(multiply)
print(result.data)
print(result.result)
43 changes: 36 additions & 7 deletions modelrunner/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,46 @@ class Result:
"""int: number indicating the version of the file format"""

def __init__(
self, model: ModelBase, result: Any, info: dict[str, Any] | None = None
self,
model: ModelBase,
result: Any,
*,
storage: StorageID = None,
info: dict[str, Any] | None = None,
):
"""
Args:
model (:class:`ModelBase`):
The model from which the result was obtained
result:
The actual result
storage:
A storage containing additional data from the model run
info (dict):
Additional information for this result
"""
if not isinstance(model, ModelBase):
raise TypeError("The model should be of type `ModelBase`")
self.model = model
self.result = result
self.model = model
self.storage = storage
self.info: Attrs = {} if info is None else info

@property
def data(self):
"""direct access to the underlying state data"""
# deprecated on 2024-04-13
warnings.warn("`.data` attribute was renamed to `.result`", DeprecationWarning)
return self.result

@classmethod
def from_data(
cls,
model_data: dict[str, Any],
result,
*,
model: ModelBase | None = None,
storage: StorageID = None,
info: dict[str, Any] | None = None,
) -> Result:
"""create result from data
Expand All @@ -86,6 +98,8 @@ def from_data(
The actual result data
model (:class:`ModelBase`):
The model from which the result was obtained
storage:
A storage containing additional data from the model run
info (dict):
Additional information for this result
Expand All @@ -103,7 +117,7 @@ def from_data(
model.name = model_data.get("name")
model.description = model_data.get("description")

return cls(model, result, info)
return cls(model, result, storage=storage, info=info)

@property
def parameters(self) -> dict[str, Any]:
Expand All @@ -119,13 +133,20 @@ def from_file(
):
"""load object from a file
This function loads the results from a hierachical storage. It also attempts to
read information about the model that was used to create this result and
additional data that might have been stored in a
:attr:`~modelrunner.results.Result.storage` while the model was running.
Args:
store (str or :class:`zarr.Store`):
Path or instance describing the storage, which is either a file path or
a :class:`zarr.Storage`.
key (str):
Name of the node in which the data was stored. This applies to some
hierarchical storage formats.
loc:
The location where the result is stored in the storage. This should
rarely be modified.
model (:class:`~modelrunner.model.ModelBase`):
The model which lead to this result
"""
if isinstance(storage, (str, Path)):
# check whether the file was written with an old format version
Expand All @@ -141,10 +162,15 @@ def from_file(
format_version = attrs.pop("format_version", None)
if format_version == cls._format_version:
# current version of storing results
if "data" in storage_obj:
data_storage = open_storage(storage, loc="data", mode="read")
else:
data_storage = None
return cls.from_data(
model_data=attrs.get("model", {}),
result=storage_obj.read_item(loc, use_class=False),
model=model,
storage=data_storage,
info=attrs.pop("info", {}), # load additional info,
)

Expand All @@ -154,7 +180,10 @@ def from_file(
def to_file(
self, storage: StorageID, loc: str = "result", *, mode: ModeType = "insert"
) -> None:
"""write this object to a file
"""write the results to a file
Note that this does only write the actual `results` but omits additional data
that might have been stored in a storage that is associated with the results.
Args:
storage (:class:`StorageBase` or :class:`StorageGroup`):
Expand Down
4 changes: 2 additions & 2 deletions tests/compatibility/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ def test_reading_compatibility(path):
try:
data = pickle.load(fp)
except ModuleNotFoundError:
assert result.data is not None # just test whether something was loaded
assert result.result is not None # just test whether something was loaded
else:
assert_data_equals(result.data, data, fuzzy=True)
assert_data_equals(result.result, data, fuzzy=True)
8 changes: 4 additions & 4 deletions tests/run/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def run(**p):
)
return Result.from_file(output)

assert run().data["a"] == 1
assert run(a=2).data["a"] == 2
assert run(b=[1, 2, 3]).data["b"] == [1, 2, 3]
assert run().result["a"] == 1
assert run(a=2).result["a"] == 2
assert run(b=[1, 2, 3]).result["b"] == [1, 2, 3]
std = capsys.readouterr()
assert std.out == std.err == ""

Expand All @@ -59,7 +59,7 @@ def test_submit_job_stdout(tmp_path, method):

assert outs == "3.0\n"
assert errs == ""
assert Result.from_file(output).data is None
assert Result.from_file(output).result is None


def test_submit_job_no_output():
Expand Down
52 changes: 26 additions & 26 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

import pytest

from modelrunner import Result, open_storage
from modelrunner.model import ModelBase, make_model, make_model_class, run_script
from modelrunner.parameters import (
DeprecatedParameter,
HideParameter,
NoValue,
Parameter,
)
from modelrunner.storage import open_storage

PACKAGEPATH = Path(__file__).parents[2].resolve()
SCRIPT_PATH = Path(__file__).parent / "scripts"
Expand All @@ -24,7 +24,7 @@
def run(script, *args):
"""run a script (with potential arguments) and collect stdout"""
result = run_script(SCRIPT_PATH / script, args)
return result.data
return result.result


def test_empty_script():
Expand Down Expand Up @@ -137,7 +137,7 @@ def __call__(self):
return self.parameters["a"] + self.parameters["b"]

assert A({"a": 3})() == 5
assert A.run_from_command_line(["--a", "3"]).data == 5
assert A.run_from_command_line(["--a", "3"]).result == 5
with pytest.raises(SystemExit):
A.run_from_command_line([])

Expand All @@ -158,7 +158,7 @@ def __call__(self):
assert A({"a": 3})() == 5
with pytest.raises(ValueError):
A({"a": 4})
assert A.run_from_command_line(["--a", "3"]).data == 5
assert A.run_from_command_line(["--a", "3"]).result == 5
with pytest.raises(SystemExit):
A.run_from_command_line(["--a", "4"])

Expand All @@ -175,7 +175,7 @@ def model1(a=2):
assert model1() == 4
assert model1(3) == 9
assert model1(a=4) == 16
assert model1.get_result().data == 4
assert model1.get_result().result == 4

@make_model
def model2(a, b=2):
Expand Down Expand Up @@ -203,7 +203,7 @@ def model_func(a=2):

assert model()() == 4
assert model({"a": 3})() == 9
assert model({"a": 4}).get_result().data == 16
assert model({"a": 4}).get_result().result == 16


def test_make_model_class_literal_args():
Expand All @@ -220,7 +220,7 @@ def model_func(a: Literal["a", 2] = 2):
with pytest.raises(ValueError):
model({"a": 3})

assert model.run_from_command_line(["--a", "a"]).data == "aa"
assert model.run_from_command_line(["--a", "a"]).result == "aa"
with pytest.raises(SystemExit):
model.run_from_command_line(["--a", "b"])

Expand All @@ -234,22 +234,22 @@ def parse_bool_0(flag: bool):

with pytest.raises(SystemExit):
parse_bool_0.run_from_command_line()
assert parse_bool_0.run_from_command_line(["--flag"]).data
assert not parse_bool_0.run_from_command_line(["--no-flag"]).data
assert parse_bool_0.run_from_command_line(["--flag"]).result
assert not parse_bool_0.run_from_command_line(["--no-flag"]).result

@make_model
def parse_bool_1(flag: bool = False):
return flag

assert not parse_bool_1.run_from_command_line().data
assert parse_bool_1.run_from_command_line(["--flag"]).data
assert not parse_bool_1.run_from_command_line().result
assert parse_bool_1.run_from_command_line(["--flag"]).result

@make_model
def parse_bool_2(flag: bool = True):
return flag

assert parse_bool_2.run_from_command_line().data
assert not parse_bool_2.run_from_command_line(["--no-flag"]).data
assert parse_bool_2.run_from_command_line().result
assert not parse_bool_2.run_from_command_line(["--no-flag"]).result


def test_argparse_list_arguments():
Expand All @@ -261,18 +261,18 @@ def parse_list_0(flag: list):

with pytest.raises(TypeError):
assert parse_list_0.run_from_command_line()
assert parse_list_0.run_from_command_line(["--flag"]).data == []
assert parse_list_0.run_from_command_line(["--flag", "0"]).data == ["0"]
assert parse_list_0.run_from_command_line(["--flag", "0", "1"]).data == ["0", "1"]
assert parse_list_0.run_from_command_line(["--flag"]).result == []
assert parse_list_0.run_from_command_line(["--flag", "0"]).result == ["0"]
assert parse_list_0.run_from_command_line(["--flag", "0", "1"]).result == ["0", "1"]

@make_model
def parse_list_1(flag: list = [0, 1]):
return flag

assert parse_list_1.run_from_command_line().data == [0, 1]
assert parse_list_1.run_from_command_line(["--flag"]).data == []
assert parse_list_1.run_from_command_line(["--flag", "0"]).data == ["0"]
assert parse_list_1.run_from_command_line(["--flag", "0", "1"]).data == ["0", "1"]
assert parse_list_1.run_from_command_line().result == [0, 1]
assert parse_list_1.run_from_command_line(["--flag"]).result == []
assert parse_list_1.run_from_command_line(["--flag", "0"]).result == ["0"]
assert parse_list_1.run_from_command_line(["--flag", "0", "1"]).result == ["0", "1"]


def test_model_class_inheritence():
Expand All @@ -296,7 +296,7 @@ def __call__(self):

assert A().parameters == {"a": 1, "b": 2, "c": 3}
assert A()() == 4
assert A.run_from_command_line(["--a", "2"]).data == 5
assert A.run_from_command_line(["--a", "2"]).result == 5
with pytest.raises(SystemExit):
A.run_from_command_line(["--b", "2"])

Expand All @@ -306,8 +306,8 @@ def __call__(self):
B.run_from_command_line(["--a", "2"])
with pytest.raises(SystemExit):
B.run_from_command_line(["--b", "2"])
assert B.run_from_command_line(["--c", "2"]).data == 8
assert B.run_from_command_line(["--d", "6"]).data == 11
assert B.run_from_command_line(["--c", "2"]).result == 8
assert B.run_from_command_line(["--d", "6"]).result == 11


def test_model_output(tmp_path):
Expand Down Expand Up @@ -350,6 +350,6 @@ def model_with_output(storage, a=3):
m = model_with_output(output=path)
m.write_result()

with open_storage(path) as storage:
assert storage["data/saved"] == {"A": "B"}
assert storage["result"].data == 5
res = Result.from_file(path)
assert res.storage["saved"] == {"A": "B"}
assert res.result == 5
10 changes: 5 additions & 5 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_result_serialization(ext, tmp_path):
# read data
read = Result.from_file(path)
assert read.model.name == "model"
np.testing.assert_equal(read.data, result.data)
np.testing.assert_equal(read.result, result.result)


@pytest.mark.skipif(not module_available("pde"), reason="requires `pde` module")
Expand All @@ -56,13 +56,13 @@ def test_pde_field_storage(ext, tmp_path):

# read data
read = Result.from_file(path)
np.testing.assert_equal(read.data, result.data)
np.testing.assert_equal(read.result, result.result)


@pytest.mark.skipif(not module_available("pde"), reason="requires `pde` module")
@pytest.mark.parametrize("ext", STORAGE_EXT)
def test_pde_trajectory_storage(ext, tmp_path):
"""test writing pde trajectories"""
def test_pde_trajectory_storage_manual(ext, tmp_path):
"""test writing pde trajectories manually"""
import pde

# create the result
Expand All @@ -79,7 +79,7 @@ def test_pde_trajectory_storage(ext, tmp_path):

# read data
read = Result.from_file(path)
assert_data_equals(read.data, result.data)
assert_data_equals(read.result, result.result)


def test_result_collections():
Expand Down

0 comments on commit cae2932

Please sign in to comment.