From 2e4b5c324f46e46b21a25115a1c7541537e0e64a Mon Sep 17 00:00:00 2001 From: yongze-yin Date: Thu, 30 May 2024 18:09:58 -0500 Subject: [PATCH] support calculate MHG length --- pyproject.toml | 2 +- src/tmhgf/MHGPartitionMP.py | 13 ++++++++----- src/tmhgf/tMHGFinder.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84dfc2c..333a5a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "tMHG-Finder" -version = "1.0.0" +version = "1.0.2" requires-python = ">= 3.7" dependencies = [ "numpy", diff --git a/src/tmhgf/MHGPartitionMP.py b/src/tmhgf/MHGPartitionMP.py index 62a6f87..5a344cb 100644 --- a/src/tmhgf/MHGPartitionMP.py +++ b/src/tmhgf/MHGPartitionMP.py @@ -95,6 +95,10 @@ def revComp(seq): revSeq = ''.join([dic[char] for char in seq[::-1]]) return revSeq +def calculate_mhg_length(mhg): + block_lengths = [(int(b[1][1]) - int(b[1][0]) + 1) for b in mhg] + return round(sum(block_lengths) / len(block_lengths)) + def blastToDf(df, threshold, constant = 1.6446838): """ Input: A blast file and user preset parameter thresholds. @@ -1407,7 +1411,7 @@ def checkPathOverlap(moduleBlockPath,genePath): def cc_mhg(S_ccIndex_blockInMhg_tuple): S = S_ccIndex_blockInMhg_tuple[0] cc_index = S_ccIndex_blockInMhg_tuple[1] - block_in_mhg = S_ccIndex_blockInMhg_tuple[2] + alignment_length_threshold = S_ccIndex_blockInMhg_tuple[2] valid_mhg = [] tempA = f'tempA_{cc_index}.bed' @@ -1495,7 +1499,7 @@ def cc_mhg(S_ccIndex_blockInMhg_tuple): modules = list(set([tuple(sorted(list(m_graph.nodes))) for m_graph in nodePathToModuleDic.values()])) for module in modules: module = [b for b in module if b[1][0] < b[1][1]] - if len(module)>= block_in_mhg: + if len(module)>= 2 and calculate_mhg_length(module) >= alignment_length_threshold: valid_mhg.append(module) if os.path.exists(tempA): os.remove(tempA) @@ -1503,7 +1507,7 @@ def cc_mhg(S_ccIndex_blockInMhg_tuple): os.remove(tempB) return valid_mhg -def mp_mhg(blastDf, block_in_mhg, thread): +def mp_mhg(blastDf, alignment_length_threshold, thread): if thread > 8: thread = 8 df_start_time = time.time() @@ -1530,8 +1534,7 @@ def mp_mhg(blastDf, block_in_mhg, thread): logger.info(f"starting traversing the alignment graph and MHG partition") total_cc_number = nx.number_strongly_connected_components(G) logger.info(f"Total {total_cc_number} cc are waiting to be visited") - cc_paramater_list = [(G.subgraph(cc),i, block_in_mhg) for i,cc in enumerate(list(nx.strongly_connected_components(G)))] - + cc_paramater_list = [(G.subgraph(cc),i, alignment_length_threshold) for i,cc in enumerate(list(nx.strongly_connected_components(G)))] logger.info(f"MHG partition started using {thread} threads") mp_t_start = time.time() p = ProcessingPool(thread) diff --git a/src/tmhgf/tMHGFinder.py b/src/tmhgf/tMHGFinder.py index 97d60a5..117400a 100644 --- a/src/tmhgf/tMHGFinder.py +++ b/src/tmhgf/tMHGFinder.py @@ -40,7 +40,7 @@ def run(genome_dir, temp_genome_dir, kmer_size, thread, mash_tree_path, blastn_d blastn_out_path = BlastnProcess.blastn_next(ready_MHG_dict, blastn_dir, temp_genome_dir, distance_matrix_dict, thread, hash_code_prefix, node_hash_dict) df, check_dict = MHGPartitionMP.parseBlastXML(blastn_out_path, alignment_length_threshold) df = MHGPartitionMP.trim_fully_contain(df, check_dict) - mhg_list = MHGPartitionMP.mp_mhg(df, 2, thread) + mhg_list = MHGPartitionMP.mp_mhg(df, alignment_length_threshold, thread) if '|' not in internal_node_taxa: # Case 1: two children nodes are both leaf nodes pan_mhg_list = ProcessMHG.pangenome_leaf(mhg_list, accDic)