-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #48 from zwicker-group/copy
Fixed bug with copying DictState
- Loading branch information
Showing
4 changed files
with
74 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
.. codeauthor:: David Zwicker <[email protected]> | ||
""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from modelrunner.state import ArrayState, DictState, ObjectState | ||
|
||
|
||
@pytest.mark.parametrize("method", ["shallow", "data", "clean"]) | ||
def test_dict_state_copy(method): | ||
"""test basic properties of states""" | ||
# define dict state classes | ||
arr_state = ArrayState(np.arange(5)) | ||
arr_state._extra = "array" | ||
obj_state = ObjectState({"list": [1, 2], "bool": True}) | ||
state = DictState({"a": arr_state, "o": obj_state}) | ||
state._extra = "state" | ||
|
||
# copy everything by copying __dict__ | ||
s_c = state.copy(method) | ||
assert state == s_c | ||
assert state is not s_c | ||
|
||
if method == "clean": | ||
assert not hasattr(s_c, "_extra") | ||
assert not hasattr(s_c["a"], "_extra") | ||
else: | ||
assert s_c._extra == "state" | ||
assert s_c["a"]._extra == "array" | ||
|
||
if method == "shallow": | ||
assert state._state_data is s_c._state_data | ||
assert state["a"] is s_c["a"] | ||
else: | ||
assert state._state_data is not s_c._state_data | ||
assert state["a"] is not s_c["a"] | ||
assert state["a"].data is not s_c["a"].data |
File renamed without changes.