Skip to content

Commit

Permalink
Catch case when windows do not span entire sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Sep 19, 2024
1 parent dcee409 commit 702385c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
5 changes: 3 additions & 2 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -9651,6 +9651,7 @@ tsk_treeseq_pair_coalescence_counts(const tsk_treeseq_t *self,
tsk_size_t num_windows, const double *windows, tsk_size_t num_bins,
const tsk_id_t *node_bin_map, tsk_flags_t options, double *result)
{
options |= TSK_REQUIRE_FULL_SPAN;
return tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes,
sample_sets, num_set_indexes, set_indexes, num_windows, windows, num_bins,
node_bin_map, pair_coalescence_weights, num_bins, NULL, options, result);
Expand Down Expand Up @@ -9767,7 +9768,7 @@ tsk_treeseq_pair_coalescence_quantiles(const tsk_treeseq_t *self,
if (ret != 0) {
goto out;
}
options |= TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE;
options |= TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE | TSK_REQUIRE_FULL_SPAN;
ret = tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes,
sample_sets, num_set_indexes, set_indexes, num_windows, windows, num_bins,
node_bin_map, pair_coalescence_quantiles, num_quantiles, params, options,
Expand Down Expand Up @@ -9885,7 +9886,7 @@ tsk_treeseq_pair_coalescence_rates(const tsk_treeseq_t *self, tsk_size_t num_sam
if (ret != 0) {
goto out;
}
options |= TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE;
options |= TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE | TSK_REQUIRE_FULL_SPAN;
ret = tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes,
sample_sets, num_set_indexes, set_indexes, num_windows, windows,
num_time_windows, node_time_window, pair_coalescence_rates, num_time_windows,
Expand Down
19 changes: 12 additions & 7 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4402,8 +4402,8 @@ def test_c_tsk_err_node_out_of_bounds(self, bad_node):
def test_c_tsk_err_bad_windows(self):
ts = self.example_ts()
L = ts.get_sequence_length()
with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"):
self.pair_coalescence_counts(ts, windows=[-1.0, L])
with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_WINDOWS"):
self.pair_coalescence_counts(ts, windows=[1.0, L])

def test_c_tsk_err_bad_node_bin_map(self):
ts = self.example_ts()
Expand Down Expand Up @@ -4526,6 +4526,11 @@ def test_c_tsk_err_unsorted_times(self):
with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSORTED_TIMES"):
self.pair_coalescence_quantiles(ts, node_bin_map=node_bin_map)

def test_c_tsk_err_bad_windows(self):
ts = self.example_ts()
with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_WINDOWS"):
self.pair_coalescence_quantiles(ts, windows=[1.0, ts.get_sequence_length()])

@pytest.mark.parametrize("bad_ss_size", [-1, 1000])
def test_cpy_bad_sample_sets(self, bad_ss_size):
ts = self.example_ts()
Expand Down Expand Up @@ -4654,6 +4659,11 @@ def test_c_tsk_err_bad_sample_pair_times(self):
with pytest.raises(_tskit.LibraryError, match="TSK_ERR_BAD_SAMPLE_PAIR_TIMES"):
self.pair_coalescence_rates(ts, time_windows=np.array([-1.0, np.inf]))

def test_c_tsk_err_bad_windows(self):
ts = self.example_ts()
with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_WINDOWS"):
self.pair_coalescence_rates(ts, windows=[1.0, ts.get_sequence_length()])

@pytest.mark.parametrize("bad_ss_size", [-1, 1000])
def test_cpy_bad_sample_sets(self, bad_ss_size):
ts = self.example_ts()
Expand All @@ -4665,11 +4675,6 @@ def test_cpy_bad_sample_sets(self, bad_ss_size):
ts, sample_set_sizes=[bad_ss_size, ts.get_num_samples()]
)

def test_cpy_bad_windows(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="at least 2"):
self.pair_coalescence_rates(ts, windows=[0.0])

@pytest.mark.parametrize("indexes", [[(0, 0, 0)], np.zeros((0, 2), dtype=np.int32)])
def test_cpy_bad_indexes(self, indexes):
ts = self.example_ts()
Expand Down

0 comments on commit 702385c

Please sign in to comment.