Skip to content

Commit

Permalink
Update tests to account for split/merge output in tdating
Browse files Browse the repository at this point in the history
  • Loading branch information
ritvje committed Jul 17, 2024
1 parent 880cc98 commit 25209c2
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 20 deletions.
65 changes: 56 additions & 9 deletions pysteps/tests/test_feature_tstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,29 @@
except ModuleNotFoundError:
pass

arg_names = ("source", "output_feat", "dry_input", "max_num_features")
arg_names = (
"source",
"output_feat",
"dry_input",
"max_num_features",
"output_split_merge",
)

arg_values = [
("mch", False, False, None),
("mch", False, False, 5),
("mch", True, False, None),
("mch", True, False, 5),
("mch", False, True, None),
("mch", False, True, 5),
("mch", False, False, None, False),
("mch", False, False, 5, False),
("mch", True, False, None, False),
("mch", True, False, 5, False),
("mch", False, True, None, False),
("mch", False, True, 5, False),
("mch", False, False, None, True),
]


@pytest.mark.parametrize(arg_names, arg_values)
def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_features):
def test_feature_tstorm_detection(
source, output_feat, dry_input, max_num_features, output_split_merge
):
pytest.importorskip("skimage")
pytest.importorskip("pandas")

Expand All @@ -36,7 +45,11 @@ def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_featur

time = "000"
output = detection(
input, time=time, output_feat=output_feat, max_num_features=max_num_features
input,
time=time,
output_feat=output_feat,
max_num_features=max_num_features,
output_splits_merges=output_split_merge,
)

if output_feat:
Expand All @@ -45,6 +58,40 @@ def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_featur
assert output.shape[1] == 2
if max_num_features is not None:
assert output.shape[0] <= max_num_features
elif output_split_merge:
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], DataFrame)
assert isinstance(output[1], np.ndarray)
if max_num_features is not None:
assert output[0].shape[0] <= max_num_features
assert output[0].shape[1] == 15
assert list(output[0].columns) == [
"ID",
"time",
"x",
"y",
"cen_x",
"cen_y",
"max_ref",
"cont",
"area",
"splitted",
"split_IDs",
"merged",
"merged_IDs",
"results_from_split",
"will_merge",
]
assert (output[0].time == time).all()
assert output[1].ndim == 2
assert output[1].shape == input.shape
if not dry_input:
assert output[0].shape[0] > 0
assert sorted(list(output[0].ID)) == sorted(list(np.unique(output[1]))[1:])
else:
assert output[0].shape[0] == 0
assert output[1].sum() == 0
else:
assert isinstance(output, tuple)
assert len(output) == 2
Expand Down
32 changes: 21 additions & 11 deletions pysteps/tests/test_tracking_tdating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@
from pysteps.utils import to_reflectivity
from pysteps.tests.helpers import get_precipitation_fields

arg_names = ("source", "dry_input")
arg_names = ("source", "dry_input", "output_splits_merges")

arg_values = [
("mch", False),
("mch", False),
("mch", True),
("mch", False, False),
("mch", False, False),
("mch", True, False),
("mch", False, True),
]

arg_names_multistep = ("source", "len_timesteps")
arg_names_multistep = ("source", "len_timesteps", "output_splits_merges")
arg_values_multistep = [
("mch", 6),
("mch", 6, False),
("mch", 6, True),
]


@pytest.mark.parametrize(arg_names_multistep, arg_values_multistep)
def test_tracking_tdating_dating_multistep(source, len_timesteps):
def test_tracking_tdating_dating_multistep(source, len_timesteps, output_splits_merges):
pytest.importorskip("skimage")

input_fields, metadata = get_precipitation_fields(
Expand All @@ -37,6 +39,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
input_fields[0 : len_timesteps // 2],
timelist[0 : len_timesteps // 2],
mintrack=1,
output_splits_merges=output_splits_merges,
)
# Second half of timesteps
tracks_2, cells, _ = dating(
Expand All @@ -46,6 +49,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
start=2,
cell_list=cells,
label_list=labels,
output_splits_merges=output_splits_merges,
)

# Since we are adding cells, number of tracks should increase
Expand All @@ -67,7 +71,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):


@pytest.mark.parametrize(arg_names, arg_values)
def test_tracking_tdating_dating(source, dry_input):
def test_tracking_tdating_dating(source, dry_input, output_splits_merges):
pytest.importorskip("skimage")
pandas = pytest.importorskip("pandas")

Expand All @@ -80,7 +84,13 @@ def test_tracking_tdating_dating(source, dry_input):

timelist = metadata["timestamps"]

output = dating(input, timelist, mintrack=1)
cell_column_length = 9
if output_splits_merges:
cell_column_length = 15

output = dating(
input, timelist, mintrack=1, output_splits_merges=output_splits_merges
)

# Check output format
assert isinstance(output, tuple)
Expand All @@ -92,12 +102,12 @@ def test_tracking_tdating_dating(source, dry_input):
assert len(output[2]) == input.shape[0]
assert isinstance(output[1][0], pandas.DataFrame)
assert isinstance(output[2][0], np.ndarray)
assert output[1][0].shape[1] == 9
assert output[1][0].shape[1] == cell_column_length
assert output[2][0].shape == input.shape[1:]
if not dry_input:
assert len(output[0]) > 0
assert isinstance(output[0][0], pandas.DataFrame)
assert output[0][0].shape[1] == 9
assert output[0][0].shape[1] == cell_column_length
else:
assert len(output[0]) == 0
assert output[1][0].shape[0] == 0
Expand Down

0 comments on commit 25209c2

Please sign in to comment.