Skip to content

Commit

Permalink
Partial
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Oct 12, 2023
1 parent 8926781 commit 773b66c
Showing 1 changed file with 120 additions and 21 deletions.
141 changes: 120 additions & 21 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""
Python implementation of the Li and Stephens forwards and backwards algorithms.
"""
import io
import warnings

import lshmm as ls
Expand All @@ -37,7 +38,8 @@
MISSING = -1


# np.set_printoptions(linewidth=1000, precision=3)
# For debugging
np.set_printoptions(linewidth=1000, precision=3)


def check_alleles(alleles, m):
Expand Down Expand Up @@ -151,7 +153,7 @@ def node_values(self):
def print_state(self):
print("LsHMM state")
print("match_all_nodes =", self.match_all_nodes)
print("Tree =")
print("Tree = ", self.tree.index, self.tree.interval)
node_labels = {}
for u, value in self.node_values().items():
label = f"{u}"
Expand Down Expand Up @@ -434,11 +436,13 @@ def update_probabilities(self, site, haplotype_state):
def process_site(self, site, haplotype_state):
self.update_probabilities(site, haplotype_state)
# d1 = self.node_values()
print("PRE")
self.print_state()
self.compress()
# d2 = self.node_values()
# assert d1 == d2
# print("AFTER COMPRESS")
# self.print_state()
print("AFTER COMPRESS")
self.print_state()
s = self.compute_normalisation_factor()
for st in self.T:
assert st.tree_node != tskit.NULL
Expand Down Expand Up @@ -489,8 +493,13 @@ def run(self, h):
self.initialise(1 / n)
while self.tree.next():
self.update_tree()
if self.tree.index != 0:
print("AFTER UPDATE TREE")
self.print_state()
for site in self.tree.sites():
self.process_site(site, h[site.id])
print("BEFORE UPDATE TREE")
self.print_state()
return self.output

def compute_normalisation_factor(self):
Expand Down Expand Up @@ -1182,7 +1191,6 @@ def verify(self, ts):
self.assertAllClose(ll, ll_check)


# TODO add params to run the various checks
def check_viterbi(
ts,
h,
Expand Down Expand Up @@ -1212,10 +1220,10 @@ def check_viterbi(
cm = ls_viterbi_tree(
h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes
)
cm.print_state()
path_tree = cm.traceback(match_all_nodes=match_all_nodes)
ll_tree = np.sum(np.log10(cm.normalisation_factor))
assert np.isscalar(ll_tree)
# print(cm)
# print("path tree = ", path_tree)

if compare_lshmm:
Expand Down Expand Up @@ -1437,8 +1445,8 @@ def test_match_sample(self, j):
ts = self.ts()
h = np.zeros(4)
h[j] = 1
# path = check_viterbi(ts, h)
# nt.assert_array_equal([j, j, j, j], path)
path = check_viterbi(ts, h)
nt.assert_array_equal([j, j, j, j], path)
cm = check_forward_matrix(ts, h)
check_backward_matrix(ts, h, cm)

Expand Down Expand Up @@ -1525,6 +1533,19 @@ def test_match_sample(self, u, h):
)


def validate_match_all_nodes(ts, h, expected_path):
path = check_viterbi(
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
)
nt.assert_array_equal(expected_path, path)
cm = check_forward_matrix(
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
)
bm = check_backward_matrix(
ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False
)


class TestSingleBalancedTreeAllNodesExample:
# 3.00┊ 6 ┊
# ┊ ┏━┻━┓ ┊
Expand All @@ -1540,7 +1561,6 @@ def ts():
tables.tree_sequence(), start=1, nodes=np.arange(len(tables.nodes) - 1)
)

# def test_match_sample(self, u, h):
@pytest.mark.parametrize(
("h", "expected_path"),
[
Expand All @@ -1558,20 +1578,99 @@ def ts():
([0, 0, 0, 0, 0, 0], [6] * 6),
],
)
def test_match_sample(self, h, expected_path):
ts = self.ts()
path = check_viterbi(
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
def test_exact_match(self, h, expected_path):
validate_match_all_nodes(self.ts(), h, expected_path)


class TestMultiTreeExample:
# 0.84┊ 7 ┊ 7 ┊
# ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊
# 0.42┊ ┃ ┃ ┊ 6 ┃ ┊
# ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊
# 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊
# ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊
# 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊
# ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊
# 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊
# 0 6 7
@staticmethod
def ts():
nodes = """\
is_sample time
1 0.000000
1 0.000000
1 0.000000
1 0.000000
0 0.041304
0 0.045967
0 0.416719
0 0.838075
"""
edges = """\
left right parent child
0.000000 7.000000 4 1
0.000000 7.000000 4 2
0.000000 6.000000 5 0
0.000000 6.000000 5 4
6.000000 7.000000 6 0
6.000000 7.000000 6 3
0.000000 6.000000 7 3
6.000000 7.000000 7 4
0.000000 6.000000 7 5
6.000000 7.000000 7 6
"""
ts = tskit.load_text(
nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False
)
return add_unique_node_mutations(ts, nodes=range(7))

# 0.84┊ 7 ┊ 7 ┊
# ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊
# 0.42┊ ┃ ┃ ┊ 6 ┃ ┊
# ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊
# 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊
# ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊
# 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊
# ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊
# 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊
# 0 6 7

@pytest.mark.parametrize(
("h", "expected_path"),
[
# Just samples
([1, 0, 0, 0, 0, 1, 1], [0] * 7),
([0, 1, 0, 0, 1, 1, 0], [1] * 7),
([0, 0, 1, 0, 1, 1, 0], [2] * 7),
([0, 0, 0, 1, 0, 0, 1], [3] * 7),
# Match root
([0, 0, 0, 0, 0, 0, 0], [7] * 7),
],
)
def test_match_all_nodes(self, h, expected_path):
# print()
# print(self.ts().draw_text())
# with open("tmp.svg", "w") as f:
# f.write(self.ts().draw_svg())
validate_match_all_nodes(self.ts(), h, expected_path)

@pytest.mark.parametrize(
("h", "expected_path"),
[
([1, 0, 0, 0, 0, 1, 1], [0] * 7),
([0, 1, 0, 0, 1, 1, 0], [1] * 7),
([0, 0, 1, 0, 1, 1, 0], [2] * 7),
([0, 0, 0, 1, 0, 0, 1], [3] * 7),
# Switch between each of the samples
([1, 1, 1, 1, 0, 0, 1], [0, 1, 2, 3, 3, 3, 3]),
],
)
def test_match_samples(self, h, expected_path):
ts = self.ts()
path = check_viterbi(ts, h)
nt.assert_array_equal(expected_path, path)
cm = check_forward_matrix(
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
)
print(cm.decode())
bm = check_backward_matrix(
ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False
)
print(bm.decode())
cm = check_forward_matrix(ts, h)
check_backward_matrix(ts, h, cm)


class TestSimulationExamples:
Expand Down

0 comments on commit 773b66c

Please sign in to comment.