Skip to content

Commit

Permalink
fix: Run rSPR when tree has duplicated nodes (#191)
Browse files Browse the repository at this point in the history
* refactor: Stop analysing trees with duplicate nodes

Signed-off-by: jvfe <[email protected]>

* fix: Update rooted_gene_trees in rspr_exact

Signed-off-by: jvfe <[email protected]>

---------

Signed-off-by: jvfe <[email protected]>
  • Loading branch information
jvfe authored May 7, 2024
1 parent 7fd1ddc commit cb6138f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
53 changes: 40 additions & 13 deletions bin/rspr_approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,21 @@ def parse_args(args=None):
)
return parser.parse_args(args)

def check_formatted_tree(tree_string):
    """Check if formatted tree has duplicate nodes.

    Leaf labels are taken to be the tokens that directly follow ``(`` or
    ``,`` in the Newick text (labels following ``)`` belong to internal
    nodes, e.g. support values, and are ignored).  Labels are compared as
    whole strings, so ``gene1`` and ``gene10`` are distinct, and the scan
    works across line breaks.

    NOTE(review): quoted Newick labels and bracketed comments are not
    parsed specially -- confirm the input never contains them.

    Args:
        tree_string: Newick text, already normalised by the caller.

    Returns:
        True if at least one leaf label occurs more than once, else False.
    """
    seen = set()
    # A leaf label starts right after '(' or ',' and runs up to the next
    # Newick delimiter or whitespace.
    for match in re.finditer(r"[(,]\s*([^\s(),:;]+)", tree_string):
        label = match.group(1)
        if label in seen:
            return True
        seen.add(label)
    return False

def read_tree(input_path):
    """Load a tree file and report whether it contains duplicate nodes.

    Before parsing, every run that starts at ``;`` and ends at the next
    ``:`` is collapsed to a single ``:``.  The duplicate check is run on
    that normalised text.

    Args:
        input_path: Path of the tree file to read.

    Returns:
        A ``(tree, is_duplicated)`` tuple: the ``Tree`` built from the
        normalised text and the boolean result of
        ``check_formatted_tree``.
    """
    with open(input_path, "r") as f:
        tree_string = f.read()

    formatted = re.sub(r";[^:]+:", ":", tree_string)
    is_duplicated = check_formatted_tree(formatted)

    return Tree(formatted), is_duplicated


#####################################################################
Expand All @@ -102,12 +111,27 @@ def read_tree(input_path):
#####################################################################


def root_tree(input_path, basename, output_path):
    """Midpoint-root a gene tree and write it to a duplicate-aware folder.

    Trees flagged as having duplicate nodes are written under
    ``<output_path>/multiple`` with ``.tre`` in the file name replaced by
    ``.tre.multiple``; all other trees go to ``<output_path>/unique``
    under their original file name.  The target folder is created if
    needed.

    Args:
        input_path: Path of the gene tree to read.
        basename: File name to use for the rooted tree.
        output_path: Base directory for rooted gene trees.

    Returns:
        A 4-tuple ``(content, n_leaves, written_path, is_duplicated)``:
        the value of ``tre.write()``, the number of leaves, the path the
        tree was written to (a ``str`` for duplicated trees, a ``Path``
        otherwise) and the duplicate flag from ``read_tree``.
    """
    tre, is_duplicated = read_tree(input_path)
    midpoint = tre.get_midpoint_outgroup()
    tre.set_outgroup(midpoint)

    outdir = Path(output_path) / ("multiple" if is_duplicated else "unique")
    outdir.mkdir(exist_ok=True, parents=True)

    if is_duplicated:
        # Rename only the file name: running the replacement over the full
        # path would also rewrite any directory component containing
        # ".tre", pointing at a folder that was never created.
        renamed = str(basename).replace(".tre", ".tre.multiple")
        output_path = str(outdir / renamed)
    else:
        output_path = outdir / basename

    tre.write(outfile=output_path)
    return tre.write(), len(tre.get_leaves()), output_path, is_duplicated

def root_reference_tree(input_path, output_path):
    """Midpoint-root the reference tree and write it to ``output_path``.

    The duplicate flag returned by ``read_tree`` is ignored here.  The
    parent directory of ``output_path`` is created when missing.

    Args:
        input_path: Path of the reference tree to read.
        output_path: File path the rooted tree is written to.

    Returns:
        A ``(content, n_leaves)`` tuple: the value of ``tre.write()`` and
        the number of leaves in the tree.
    """
    tre, _ = read_tree(input_path)
    midpoint = tre.get_midpoint_outgroup()
    tre.set_outgroup(midpoint)

    outdir = os.path.dirname(output_path)
    # A bare file name has an empty dirname, and os.makedirs("") raises;
    # exist_ok avoids the check-then-create race.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    tre.write(outfile=output_path)
    return tre.write(), len(tre.get_leaves())

Expand Down Expand Up @@ -135,20 +159,23 @@ def root_trees(core_tree, gene_trees_path, output_dir, results, merge_pair=False
rooted_reference_tree = os.path.join(
output_dir, "rooted_reference_tree/core_gene_alignment.tre"
)
refer_content, refer_tree_size = root_tree(reference_tree, rooted_reference_tree)
refer_content, refer_tree_size = root_reference_tree(reference_tree, rooted_reference_tree)

df_gene_trees = pd.read_csv(gene_trees_path)
rooted_gene_trees_path = os.path.join(output_dir, "rooted_gene_trees")
for filename in df_gene_trees["path"]:
basename = Path(filename).name
rooted_gene_tree_path = os.path.join(rooted_gene_trees_path, basename)
gene_content, gene_tree_size = root_tree(filename, rooted_gene_tree_path)
results.loc[basename, "tree_size"] = gene_tree_size
gene_content, gene_tree_size, gene_tree_path, is_duplicated = root_tree(
filename,
basename,
rooted_gene_trees_path)
if not is_duplicated:
results.loc[basename, "tree_size"] = gene_tree_size
if merge_pair:
with open(rooted_gene_tree_path, "w") as f2:
with open(gene_tree_path, "w") as f2:
f2.write(refer_content + "\n" + gene_content)
#'''
return rooted_gene_trees_path
return os.path.join(rooted_gene_trees_path, "unique")


#####################################################################
Expand Down Expand Up @@ -212,7 +239,7 @@ def approx_rspr(
"-length " + str(min_branch_len),
"-support " + str(max_support_threshold),
]

group_size = 10000
cur_count = 0
lst_filename = []
Expand Down Expand Up @@ -498,7 +525,7 @@ def main(args=None):
# Generate group heatmap
group_fig_path = os.path.join(args.OUTPUT_DIR, "group_output.png")
make_group_heatmap(
results,
results,
group_fig_path,
args.MIN_HEATMAP_RSPR_DISTANCE,
args.MAX_HEATMAP_RSPR_DISTANCE
Expand Down
6 changes: 3 additions & 3 deletions bin/rspr_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def fpt_rspr(results_df, min_branch_len=0, max_support_threshold=0.7, gather_clu
"-support " + str(max_support_threshold),
]

trees_path = os.path.join("rooted_gene_trees")
trees_path = os.path.join("rooted_gene_trees/unique")

cluster_file = None
if gather_cluster_info:
Expand Down Expand Up @@ -123,13 +123,13 @@ def fpt_rspr(results_df, min_branch_len=0, max_support_threshold=0.7, gather_clu
continue
elif "Clusters end" in line:
clustering_start = False

if clustering_start:
updated_line = line.replace('(', '').replace(')', '').replace('\n', '')
cluster_nodes = updated_line.split(',')
cluster_nodes = [int(node) for node in cluster_nodes if "X" not in node]
clusters.append(cluster_nodes)

output_lines.append(line)
cluster_file.write(json.dumps(clusters) + '\n')
process.wait()
Expand Down

0 comments on commit cb6138f

Please sign in to comment.