From c572ea319da55dde3b5c2c2d5dd335a8db93a013 Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Mon, 28 Aug 2023 20:49:18 +0200 Subject: [PATCH] Fixed bug with copying DictState * Fixed some docstrings * Renamed a test file --- modelrunner/state/base.py | 6 +-- modelrunner/state/dict.py | 32 +++++++++++++++ tests/state/test_dict.py | 39 +++++++++++++++++++ .../state/{test_state.py => test_generic.py} | 0 4 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 tests/state/test_dict.py rename tests/state/{test_state.py => test_generic.py} (100%) diff --git a/modelrunner/state/base.py b/modelrunner/state/base.py index cf953fa..005481a 100644 --- a/modelrunner/state/base.py +++ b/modelrunner/state/base.py @@ -281,12 +281,12 @@ def copy(self: TState, method: str, data=None) -> TState: in the :attr:`_state_data`, which typically is aliased by :attr:`data`). Args: - data: - Data to be used instead of the one in the current state. This data is - used as is and not copied! method (str): Determines whether a `clean`, `shallow`, or `data` copy is performed. See description above for details. + data: + Data to be used instead of the one in the current state. This data is + used as is and not copied! Returns: A copy of the current state object diff --git a/modelrunner/state/dict.py b/modelrunner/state/dict.py index 45f6bf3..b8d7b9c 100644 --- a/modelrunner/state/dict.py +++ b/modelrunner/state/dict.py @@ -65,6 +65,38 @@ def from_data(cls, attributes: Dict[str, Any], data=None): data = {k: v for k, v in zip(attributes.pop("__keys__"), data)} return super().from_data(attributes, data) + def copy(self, method: str, data=None): + """create a copy of the state + + Args: + method (str): + Determines whether a `clean`, `shallow`, or `data` copy is performed. + See :meth:`~modelrunner.state.base.StateBase.copy` for details. + data: + Data to be used instead of the one in the current state. This data is + used as is and not copied! + + Returns: + A copy of the current state object + """ + if method == "data": + # This special copy mode needs to be implemented in a very special way for + # `DictState` since only the data needs to be deep-copied, while all other + # attributes shall receive shallow copies. This particularly also needs to + # hold for the substates stored in `_state_data`. + obj = self.__class__.__new__(self.__class__) + obj.__dict__ = self.__dict__.copy() + if data is None: + obj._state_data = { + k: v.copy(method="data") for k, v in self._state_data.items() + } + else: + obj._state_data = data + + else: + obj = super().copy(method=method, data=data) + return obj + def __len__(self) -> int: return len(self._state_data) diff --git a/tests/state/test_dict.py b/tests/state/test_dict.py new file mode 100644 index 0000000..41c5052 --- /dev/null +++ b/tests/state/test_dict.py @@ -0,0 +1,39 @@ +""" +.. codeauthor:: David Zwicker +""" + +import numpy as np +import pytest + +from modelrunner.state import ArrayState, DictState, ObjectState + + +@pytest.mark.parametrize("method", ["shallow", "data", "clean"]) +def test_dict_state_copy(method): + """test basic properties of states""" + # define dict state classes + arr_state = ArrayState(np.arange(5)) + arr_state._extra = "array" + obj_state = ObjectState({"list": [1, 2], "bool": True}) + state = DictState({"a": arr_state, "o": obj_state}) + state._extra = "state" + + # copy everything by copying __dict__ + s_c = state.copy(method) + assert state == s_c + assert state is not s_c + + if method == "clean": + assert not hasattr(s_c, "_extra") + assert not hasattr(s_c["a"], "_extra") + else: + assert s_c._extra == "state" + assert s_c["a"]._extra == "array" + + if method == "shallow": + assert state._state_data is s_c._state_data + assert state["a"] is s_c["a"] + else: + assert state._state_data is not s_c._state_data + assert state["a"] is not s_c["a"] + assert state["a"].data is not s_c["a"].data diff --git a/tests/state/test_state.py b/tests/state/test_generic.py similarity index 100% rename from tests/state/test_state.py rename to tests/state/test_generic.py