From 5564460b37e92d20a548c2712b95fd777da6a3a0 Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Mon, 28 Aug 2023 16:42:59 +0200 Subject: [PATCH] `copy.copy` of states now also copies state data --- modelrunner/state/base.py | 12 ++++++++++++ tests/state/test_state.py | 9 ++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/modelrunner/state/base.py b/modelrunner/state/base.py index b2390a5..9782512 100644 --- a/modelrunner/state/base.py +++ b/modelrunner/state/base.py @@ -305,6 +305,18 @@ def copy(self: TState, data=None, *, clean: bool = True) -> TState: obj._state_data = data return obj + def __copy__(self: TState) -> TState: + """create a shallow copy of the state using :meth:`copy.copy` + + This method inserts references into the new state to the objects found in the + original state. The only exception to this rule is the `data` attribute, which + will actually be copied. + + Returns: + A copy of the current state object + """ + return self.copy(clean=False) + def _state_write_zarr_attributes( self, element: zarrElement, attrs: Optional[Dict[str, Any]] = None ) -> zarrElement: diff --git a/tests/state/test_state.py b/tests/state/test_state.py index 1a8fa4c..a1c04c8 100644 --- a/tests/state/test_state.py +++ b/tests/state/test_state.py @@ -25,16 +25,19 @@ def test_state_basic(state): assert state.__class__.__name__ in StateBase._state_classes s2 = state.copy(clean=False) - assert state is not s2 assert state == s2 + assert state is not s2 + assert state._state_data is not s2._state_data s3 = state.copy(clean=True) - assert state is not s3 assert state == s3 + assert state is not s3 + assert state._state_data is not s3._state_data s4 = copy.copy(state) - assert state is not s4 assert state == s4 + assert state is not s4 + assert state._state_data is not s4._state_data @pytest.mark.parametrize("state", get_states(add_derived=False))