Skip to content

Commit

Permalink
support calculate MHG length
Browse files Browse the repository at this point in the history
  • Loading branch information
yongze-yin committed May 30, 2024
1 parent dcb1a8e commit 2e4b5c3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "tMHG-Finder"
version = "1.0.0"
version = "1.0.2"
requires-python = ">= 3.7"
dependencies = [
"numpy",
Expand Down
13 changes: 8 additions & 5 deletions src/tmhgf/MHGPartitionMP.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def revComp(seq):
revSeq = ''.join([dic[char] for char in seq[::-1]])
return revSeq

def calculate_mhg_length(mhg):
block_lengths = [(int(b[1][1]) - int(b[1][0]) + 1) for b in mhg]
return round(sum(block_lengths) / len(block_lengths))

def blastToDf(df, threshold, constant = 1.6446838):
"""
Input: A blast file and user preset parameter thresholds.
Expand Down Expand Up @@ -1407,7 +1411,7 @@ def checkPathOverlap(moduleBlockPath,genePath):
def cc_mhg(S_ccIndex_blockInMhg_tuple):
S = S_ccIndex_blockInMhg_tuple[0]
cc_index = S_ccIndex_blockInMhg_tuple[1]
block_in_mhg = S_ccIndex_blockInMhg_tuple[2]
alignment_length_threshold = S_ccIndex_blockInMhg_tuple[2]

valid_mhg = []
tempA = f'tempA_{cc_index}.bed'
Expand Down Expand Up @@ -1495,15 +1499,15 @@ def cc_mhg(S_ccIndex_blockInMhg_tuple):
modules = list(set([tuple(sorted(list(m_graph.nodes))) for m_graph in nodePathToModuleDic.values()]))
for module in modules:
module = [b for b in module if b[1][0] < b[1][1]]
if len(module)>= block_in_mhg:
if len(module)>= 2 and calculate_mhg_length(module) >= alignment_length_threshold:
valid_mhg.append(module)
if os.path.exists(tempA):
os.remove(tempA)
if os.path.exists(tempB):
os.remove(tempB)
return valid_mhg

def mp_mhg(blastDf, block_in_mhg, thread):
def mp_mhg(blastDf, alignment_length_threshold, thread):
if thread > 8:
thread = 8
df_start_time = time.time()
Expand All @@ -1530,8 +1534,7 @@ def mp_mhg(blastDf, block_in_mhg, thread):
logger.info(f"starting traversing the alignment graph and MHG partition")
total_cc_number = nx.number_strongly_connected_components(G)
logger.info(f"Total {total_cc_number} cc are waiting to be visited")
cc_paramater_list = [(G.subgraph(cc),i, block_in_mhg) for i,cc in enumerate(list(nx.strongly_connected_components(G)))]

cc_paramater_list = [(G.subgraph(cc),i, alignment_length_threshold) for i,cc in enumerate(list(nx.strongly_connected_components(G)))]
logger.info(f"MHG partition started using {thread} threads")
mp_t_start = time.time()
p = ProcessingPool(thread)
Expand Down
2 changes: 1 addition & 1 deletion src/tmhgf/tMHGFinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run(genome_dir, temp_genome_dir, kmer_size, thread, mash_tree_path, blastn_d
blastn_out_path = BlastnProcess.blastn_next(ready_MHG_dict, blastn_dir, temp_genome_dir, distance_matrix_dict, thread, hash_code_prefix, node_hash_dict)
df, check_dict = MHGPartitionMP.parseBlastXML(blastn_out_path, alignment_length_threshold)
df = MHGPartitionMP.trim_fully_contain(df, check_dict)
mhg_list = MHGPartitionMP.mp_mhg(df, 2, thread)
mhg_list = MHGPartitionMP.mp_mhg(df, alignment_length_threshold, thread)
if '|' not in internal_node_taxa:
# Case 1: two children nodes are both leaf nodes
pan_mhg_list = ProcessMHG.pangenome_leaf(mhg_list, accDic)
Expand Down

0 comments on commit 2e4b5c3

Please sign in to comment.