Skip to content

Commit

Permalink
Merge pull request #48 from zwicker-group/copy
Browse files Browse the repository at this point in the history
Fixed bug with copying DictState
  • Loading branch information
david-zwicker authored Aug 29, 2023
2 parents 2b0a12a + c572ea3 commit e8e3fae
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 3 deletions.
6 changes: 3 additions & 3 deletions modelrunner/state/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,12 @@ def copy(self: TState, method: str, data=None) -> TState:
in the :attr:`_state_data`, which typically is aliased by :attr:`data`).
Args:
data:
Data to be used instead of the one in the current state. This data is
used as is and not copied!
method (str):
Determines whether a `clean`, `shallow`, or `data` copy is performed.
See description above for details.
data:
Data to be used instead of the one in the current state. This data is
used as is and not copied!
Returns:
A copy of the current state object
Expand Down
32 changes: 32 additions & 0 deletions modelrunner/state/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,38 @@ def from_data(cls, attributes: Dict[str, Any], data=None):
data = {k: v for k, v in zip(attributes.pop("__keys__"), data)}
return super().from_data(attributes, data)

def copy(self, method: str, data=None):
"""create a copy of the state
Args:
method (str):
Determines whether a `clean`, `shallow`, or `data` copy is performed.
See :meth:`~modelrunner.state.base.StateBase.copy` for details.
data:
Data to be used instead of the one in the current state. This data is
used as is and not copied!
Returns:
A copy of the current state object
"""
if method == "data":
# This special copy mode needs to be implemented in a very special way for
# `DictState` since only the data needs to be deep-copied, while all other
# attributes shall receive shallow copies. This particularly also needs to
# hold for the substates stored in `_state_data`.
obj = self.__class__.__new__(self.__class__)
obj.__dict__ = self.__dict__.copy()
if data is None:
obj._state_data = {
k: v.copy(method="data") for k, v in self._state_data.items()
}
else:
obj._state_data = data

else:
obj = super().copy(method=method, data=data)
return obj

def __len__(self) -> int:
return len(self._state_data)

Expand Down
39 changes: 39 additions & 0 deletions tests/state/test_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
.. codeauthor:: David Zwicker <[email protected]>
"""

import numpy as np
import pytest

from modelrunner.state import ArrayState, DictState, ObjectState


@pytest.mark.parametrize("method", ["shallow", "data", "clean"])
def test_dict_state_copy(method):
"""test basic properties of states"""
# define dict state classes
arr_state = ArrayState(np.arange(5))
arr_state._extra = "array"
obj_state = ObjectState({"list": [1, 2], "bool": True})
state = DictState({"a": arr_state, "o": obj_state})
state._extra = "state"

# copy everything by copying __dict__
s_c = state.copy(method)
assert state == s_c
assert state is not s_c

if method == "clean":
assert not hasattr(s_c, "_extra")
assert not hasattr(s_c["a"], "_extra")
else:
assert s_c._extra == "state"
assert s_c["a"]._extra == "array"

if method == "shallow":
assert state._state_data is s_c._state_data
assert state["a"] is s_c["a"]
else:
assert state._state_data is not s_c._state_data
assert state["a"] is not s_c["a"]
assert state["a"].data is not s_c["a"].data
File renamed without changes.

0 comments on commit e8e3fae

Please sign in to comment.