Skip to content

Commit

Permalink
Merge pull request #47 from zwicker-group/copy
Browse files Browse the repository at this point in the history
`copy.copy` of states now also copies state data
  • Loading branch information
david-zwicker authored Aug 28, 2023
2 parents 5cdb95d + 5564460 commit ba5a5a3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
12 changes: 12 additions & 0 deletions modelrunner/state/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,18 @@ def copy(self: TState, data=None, *, clean: bool = True) -> TState:
obj._state_data = data
return obj

def __copy__(self: TState) -> TState:
"""create a shallow copy of the state using :meth:`copy.copy`
This method inserts references into the new state to the objects found in the
original state. The only exception to this rule is the `data` attribute, which
will actually be copied.
Returns:
A copy of the current state object
"""
return self.copy(clean=False)

def _state_write_zarr_attributes(
self, element: zarrElement, attrs: Optional[Dict[str, Any]] = None
) -> zarrElement:
Expand Down
9 changes: 6 additions & 3 deletions tests/state/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@ def test_state_basic(state):
assert state.__class__.__name__ in StateBase._state_classes

s2 = state.copy(clean=False)
assert state is not s2
assert state == s2
assert state is not s2
assert state._state_data is not s2._state_data

s3 = state.copy(clean=True)
assert state is not s3
assert state == s3
assert state is not s3
assert state._state_data is not s3._state_data

s4 = copy.copy(state)
assert state is not s4
assert state == s4
assert state is not s4
assert state._state_data is not s4._state_data


@pytest.mark.parametrize("state", get_states(add_derived=False))
Expand Down

0 comments on commit ba5a5a3

Please sign in to comment.