From e99916ee5489319335aae337e7567a79a42da226 Mon Sep 17 00:00:00 2001
From: jonperdomo
Date: Wed, 5 Jun 2024 16:19:37 -0400
Subject: [PATCH] Update scikit-learn version and implement HDBSCAN clustering

---
 environment.yml     |   4 +-
 python/sv_merger.py | 350 ++++++++++++++++++--------------------------
 2 files changed, 143 insertions(+), 211 deletions(-)

diff --git a/environment.yml b/environment.yml
index 8820b657..3632b640 100644
--- a/environment.yml
+++ b/environment.yml
@@ -12,5 +12,7 @@ dependencies:
   - pytest
   - plotly
   - pandas
-  - scikit-learn
+  - scikit-learn>=1.3
   - joblib
+
+# conda env create --name contextsv --file environment.yml
diff --git a/python/sv_merger.py b/python/sv_merger.py
index d4a43e73..adddf8a9 100644
--- a/python/sv_merger.py
+++ b/python/sv_merger.py
@@ -18,10 +18,13 @@
 import matplotlib.pyplot as plt  # For plotting merge behavior
 
 # DBSCAN clustering algorithm
-from sklearn.cluster import DBSCAN
+# from sklearn.cluster import DBSCAN
 
-# Agglomerative clustering algorithm
-from sklearn.cluster import AgglomerativeClustering
+# OPTICS clustering algorithm
+# from sklearn.cluster import OPTICS
+
+# HDBSCAN clustering algorithm
+from sklearn.cluster import HDBSCAN
 
 
 def plot_dbscan(breakpoints, chosen_breakpoints, filename='dbscan_clustering.png'):
@@ -74,19 +77,11 @@ def update_support(record, cluster_size):
     info_fields = info.split(';')
 
     # Loop and update the SUPPORT field, while creating a new INFO string
-    support = 0
     updated_info = ''
     for field in info_fields:
         if field.startswith('SUPPORT='):
-            support = int(field.split('=')[1])
-
-            # Set the SUPPORT field to the cluster size
-            # updated_info = f'SUPPORT={cluster_size};'
             updated_info += f'SUPPORT={cluster_size};'
-
-            # # Increment the SUPPORT field by the cluster size
-            # support += cluster_size
-            # updated_info += f'SUPPORT={support};'
         else:
             updated_info += field + ';'  # Append the field to the updated INFO
@@ -95,6 +90,116 @@ def update_support(record, cluster_size):
     return record
 
 
+def cluster_breakpoints(vcf_df, sv_type, min_cluster_size=2):
+    """
+    Cluster SV breakpoints using HDBSCAN and return one merged record per cluster.
+    """
+    # Set up the output DataFrame
+    merged_records = pd.DataFrame(columns=['INDEX', 'CHROM', 'POS', 'INFO'])
+
+    # Format the SV breakpoints
+    breakpoints = None
+    if sv_type == 'DEL':
+        sv_start = vcf_df['POS'].values
+        sv_end = vcf_df['INFO'].str.extract(r'END=(\d+)', expand=False).astype(np.int32)
+
+        # Format the deletion breakpoints
+        breakpoints = np.column_stack((sv_start, sv_end))
+
+    elif sv_type == 'INS' or sv_type == 'DUP':
+        sv_start = vcf_df['POS'].values
+        sv_len = vcf_df['INFO'].str.extract(r'SVLEN=(-?\d+)', expand=False).astype(np.int32)
+        sv_end = sv_start + sv_len - 1
+
+        # Format the insertion and duplication breakpoints
+        breakpoints = np.column_stack((sv_start, sv_end))
+    else:
+        logging.error(f"Invalid SV type: {sv_type}")
+        return merged_records
+
+    # Get the combined SV read and clipped base support
+    sv_support = vcf_df['INFO'].str.extract(r'SUPPORT=(\d+)', expand=False).astype(np.int32)
+    sv_clipped_base_support = vcf_df['INFO'].str.extract(r'CLIPSUP=(\d+)', expand=False).astype(np.int32)
+    sv_support = sv_support + sv_clipped_base_support
+
+    # Get the HMM likelihood scores
+    hmm_scores = vcf_df['INFO'].str.extract(r'HMM=(\d+)', expand=False).astype(np.float32)
+
+    # Cluster SV breakpoints using HDBSCAN
+    cluster_labels = []
+    hdbscan = HDBSCAN(min_cluster_size=min_cluster_size)
+    if len(breakpoints) > 0:
+        logging.info(f"Clustering {sv_type} breakpoints using HDBSCAN with minimum cluster size={min_cluster_size}...")
+        cluster_labels = hdbscan.fit_predict(breakpoints)
+
+    logging.info(f"Number of unique cluster labels: {len(np.unique(cluster_labels))}")
+
+    # Merge SVs with the same label
+    unique_labels = np.unique(cluster_labels)
+    for label in unique_labels:
+
+        # Skip label -1 (outliers)
+        if label == -1:
+            # Log the outlier positions if any fall within the debug region
+            pos_min = 180915940
+            pos_max = 180950356
+            idx = cluster_labels == label
+            pos_values = breakpoints[idx][:, 0]
+            if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)):
+                # Log all within range
+                pos_within_range = pos_values[(pos_values >= pos_min) & (pos_values <= pos_max)]
+                logging.info(f"Outlier {sv_type} positions: {pos_within_range}")
+                # logging.info(f"Outlier {sv_type} positions: {pos_values}")
+
+            continue
+
+        # Get the indices of SVs with the same label
+        idx = cluster_labels == label
+
+        # Use the SV with the lowest log likelihood score, if available
+        # (values are not all the same)
+        best_score_idx = None
+        cluster_hmm_scores = hmm_scores[idx]
+        cluster_depth_scores = sv_support[idx]
+        if len(np.unique(cluster_hmm_scores)) > 1:
+            best_score_idx = np.argmin(cluster_hmm_scores)
+
+        # Use the SV with the highest read support if the log likelihood
+        # scores are all the same
+        elif len(np.unique(cluster_depth_scores)) > 1:
+            best_score_idx = np.argmax(cluster_depth_scores)
+
+        # Use the first SV in the cluster if the depth scores are all the
+        # same
+        else:
+            best_score_idx = 0
+
+        # Get the chosen VCF record for this cluster
+        best_record = vcf_df.iloc[idx, :].iloc[best_score_idx, :]
+
+        # Get the number of SVs in this cluster
+        cluster_size = np.sum(idx)
+        # logging.info("Cluster size: %s", cluster_size)
+
+        # Update the SUPPORT field in the INFO column
+        best_record = update_support(best_record, cluster_size)
+
+        # Get all position values in the cluster
+        pos_values = breakpoints[idx][:, 0]
+
+        # Log the cluster if any position falls within the debug region
+        pos_min = 180915940
+        pos_max = 180950356
+        if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)) or cluster_size > 1000:
logging.info(f"Cluster size: {cluster_size}") + logging.info(f"Pos values: {pos_values}") + + # Append the chosen record to the dataframe of records that will + # form the merged VCF file + merged_records.loc[merged_records.shape[0]] = max_del_record + + return merged_records + def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.merged'): """ @@ -120,7 +225,6 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer logging.info(f"Total number of records: {vcf_df.shape[0]}") # Store a dataframe of records that will form the merged VCF file - # merged_records = pd.DataFrame(columns=['CHROM', 'POS', 'INFO']) merged_records = pd.DataFrame(columns=['INDEX', 'CHROM', 'POS', 'INFO']) # Create a set with each chromosome in the VCF file @@ -137,208 +241,38 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer for chromosome in chromosomes: logging.info(f"Clustering SVs on chromosome {chromosome}...") - # Get the chromosome deletion, insertion, and duplication breakpoints + # Cluster deletions + logging.info(f"Clustering deletions on chromosome {chromosome}...") chr_del_df = vcf_df[(vcf_df['CHROM'] == chromosome) & (vcf_df['INFO'].str.contains('SVTYPE=DEL'))] - chr_ins_dup_df = vcf_df[(vcf_df['CHROM'] == chromosome) & ((vcf_df['INFO'].str.contains('SVTYPE=INS')) | (vcf_df['INFO'].str.contains('SVTYPE=DUP')))] - - # Get the deletion start and end positions - chr_del_start = chr_del_df['POS'].values - chr_del_end = chr_del_df['INFO'].str.extract(r'END=(\d+)', expand=False).astype(np.int32) - - # Get the insertion and duplication start and end positions - chr_ins_dup_start = chr_ins_dup_df['POS'].values - chr_ins_dup_len = chr_ins_dup_df['INFO'].str.extract(r'SVLEN=(-?\d+)', expand=False).astype(np.int32) - chr_ins_dup_end = chr_ins_dup_start + chr_ins_dup_len - 1 - - # Format the deletion breakpoints - chr_del_breakpoints = np.column_stack((chr_del_start, chr_del_end)) - logging.info("Number of deletion breakpoints: " + str(len(chr_del_breakpoints))) + del_records = cluster_breakpoints(chr_del_df, 'DEL', min_samples=min_samples) - # Format the insertion and duplication breakpoints - chr_ins_dup_breakpoints = np.column_stack((chr_ins_dup_start, chr_ins_dup_end)) - logging.info("Number of insertion and duplication breakpoints: " + str(len(chr_ins_dup_breakpoints))) - - # Get the SV depth and clipped base evidence scores for deletions - # chr_del_depth_scores = chr_del_df['INFO'].str.extract(r'DP=(\d+)', - # expand=False).astype(np.int32) - chr_del_support = chr_del_df['INFO'].str.extract(r'SUPPORT=(\d+)', expand=False).astype(np.int32) - chr_del_clipped_bases = chr_del_df['INFO'].str.extract(r'CLIPSUP=(\d+)', expand=False).astype(np.int32) - chr_del_depth_scores = chr_del_support + chr_del_clipped_bases - - # Get the SV depth and clipped base evidence scores for insertions - # and duplications - # chr_ins_dup_depth_scores = - # chr_ins_dup_df['INFO'].str.extract(r'DP=(\d+)', - # expand=False).astype(np.int32) - chr_ins_dup_support = chr_ins_dup_df['INFO'].str.extract(r'SUPPORT=(\d+)', expand=False).astype(np.int32) - chr_ins_dup_clipped_bases = chr_ins_dup_df['INFO'].str.extract(r'CLIPSUP=(\d+)', expand=False).astype(np.int32) - chr_ins_dup_depth_scores = chr_ins_dup_support + chr_ins_dup_clipped_bases - - # Cluster SV breakpoints using DBSCAN - dbscan = DBSCAN(eps=eps, min_samples=min_samples) - - # Cluster deletion breakpoints - if len(chr_del_breakpoints) > 0: - logging.info(f"Clustering deletion breakpoints using 
-            del_labels = dbscan.fit_predict(chr_del_breakpoints)
-            logging.info(f"Deletion label counts: {len(np.unique(del_labels))}")
-        else:
-            del_labels = []
-
-        # Cluster insertion and duplication breakpoints together since
-        # insertions are a subset of duplications
-        if len(chr_ins_dup_breakpoints) > 0:
-            logging.info(f"Clustering insertion and duplication breakpoints using DBSCAN with eps={eps} and min_samples={min_samples}...")
-            ins_dup_labels = dbscan.fit_predict(chr_ins_dup_breakpoints)
-            logging.info(f"Insertion and duplication label counts: {len(np.unique(ins_dup_labels))}")
-        else:
-            ins_dup_labels = []
-
-
-        # Get the unique deletion labels for the chromosome
-        unique_del_labels = np.unique(del_labels)
-
-        # Get the unique insertion and duplication labels for the chromosome
-        unique_ins_dup_labels = np.unique(ins_dup_labels)
-
-        # Merge deletions with the same label
-        del_count = 0
-        for label in unique_del_labels:
-
-            # Skip label -1 (outliers)
-            if label == -1:
-                # Print the positions if any are within a certain range
-                pos_min = 180915940
-                pos_max = 180950356
-                idx = del_labels == label
-                pos_values = chr_del_breakpoints[idx][:, 0]
-                if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)):
-                    # Print all within range
-                    pos_within_range = pos_values[(pos_values >= pos_min) & (pos_values <= pos_max)]
-                    logging.info(f"Outlier deletion positions: {pos_within_range}")
-                    # logging.info(f"Outlier deletion positions: {pos_values}")
-
-                continue
-
-            # Get the indices of SVs with the same label
-            idx = del_labels == label
-
-            # Get the SV depth scores with the same label
-            depth_scores = chr_del_depth_scores[idx]
-
-            # Get the index of the SV with the highest depth score
-            max_depth_score_idx = np.argmax(depth_scores)
-
-            # Get the VCF record with the highest depth score
-            max_del_record = chr_del_df.iloc[idx, :].iloc[max_depth_score_idx, :]
-
-            # Get the number of SVs in this cluster
-            cluster_size = np.sum(idx)
-            # logging.info("DEL Cluster size: %s", cluster_size)
-
-            # Update the SUPPORT field in the INFO column
-            max_del_record = update_support(max_del_record, cluster_size)
+        # Cluster insertions and duplications
+        logging.info(f"Clustering insertions and duplications on chromosome {chromosome}...")
+        chr_ins_dup_df = vcf_df[(vcf_df['CHROM'] == chromosome) & ((vcf_df['INFO'].str.contains('SVTYPE=INS')) | (vcf_df['INFO'].str.contains('SVTYPE=DUP')))]
+        ins_dup_records = cluster_breakpoints(chr_ins_dup_df, 'INS', min_cluster_size=min_samples)
 
-            # Get all position values in the cluster
-            pos_values = chr_del_breakpoints[idx][:, 0]
+        # Summarize the number of merged deletions and insertions/duplications
+        del_count = del_records.shape[0]
+        ins_dup_count = ins_dup_records.shape[0]
+        records_processed += chr_del_df.shape[0] + chr_ins_dup_df.shape[0]
+        logging.info(f"Chromosome {chromosome} - {del_count} deletions and {ins_dup_count} insertions/duplications merged.")
 
-            # If the POS value is a certain value, plot the support
-            # target_pos = 180949294
-            pos_min = 180915940
-            pos_max = 180950356
-            # if max_del_record['POS'] == target_pos:
-            # if max_del_record['POS'] >= pos_min and max_del_record['POS'] <=
-            # pos_max:
-            # If any position in the cluster is within the range, log the
-            # information (or if the cluster size is greater than 10)
-            if (np.any(pos_values >= pos_min) and np.any(pos_values <= pos_max)) or cluster_size > 1000:
-                logging.info(f"Cluster size: {cluster_size}")
-                logging.info(f"Pos values: {pos_values}")
-                # logging.info(f"FOUND POS {target_pos} - Cluster size:
-                # {cluster_size}")
-                # logging.info(f"FOUND POS {max_del_record['POS']} - Cluster size: {cluster_size}")
-                # logging.info(f"INFO: {max_del_record['INFO']}")
-
-            # Append the chosen record to the dataframe of records that will
-            # form the merged VCF file
-            merged_records.loc[merged_records.shape[0]] = max_del_record
-
-            # Plot the DBSCAN clustering behavior if there are 10 < X < 20 SVs with the same label
-            plot_enabled = False
-            if plot_enabled:
-                logging.info(f"Plotting DBSCAN clustering behavior for label {label}...")
-                if len(chr_del_breakpoints[idx]) > 10 and len(chr_del_breakpoints[idx]) < 20 and num_plots < max_plots:
-
-                    # Increment the number of plots
-                    num_plots += 1
-
-                    # Convert the max depth score index (index within labels) to the index within the original deletion DataFrame
-                    chosen_idx = np.where(idx)[0][max_depth_score_idx]
-                    chosen_breakpoints = chr_del_breakpoints[chosen_idx]
-                    plot_dbscan(chr_del_breakpoints[idx], chosen_breakpoints, filename=f"dbscan_clustering_{num_plots}.png")
-
-                    # Return if the number of plots is reached
-                    if num_plots == max_plots:
-                        return
-
-        # Merge insertions and duplications with the same label
-        ins_dup_count = 0
-        for label in unique_ins_dup_labels:
-
-            # Skip label -1 (outliers)
-            if label == -1:
-                continue
-
-            # Get the indices of SVs with the same label. IDX is a boolean
-            # array, where True indicates the SV has the same label
-            idx = ins_dup_labels == label
-
-            # Get the SV depth scores with the same label
-            depth_scores = chr_ins_dup_depth_scores[idx]
-
-            # Get the index of the SV with the highest depth score
-            max_depth_score_idx = np.argmax(depth_scores)
-
-            # Get the VCF record with the highest depth score
-            max_ins_dup_record = chr_ins_dup_df.iloc[idx, :].iloc[max_depth_score_idx, :]
-
-            # Get the number of SVs in this cluster
-            cluster_size = np.sum(idx)
-            # logging.info("DUP Cluster size: %s", cluster_size)
-
-            # Update the SUPPORT field in the INFO column
-            max_ins_dup_record = update_support(max_ins_dup_record, cluster_size)
-
-            # Append the chosen record to the dataframe of records that will
-            # form the merged VCF file
-            merged_records.loc[merged_records.shape[0]] = max_ins_dup_record
+        # Append the deletion and insertion/duplication records to the merged
+        # records DataFrame
+        merged_records = pd.concat([merged_records, del_records, ins_dup_records], ignore_index=True)
 
-        logging.info(f"Chromosome {chromosome} - {del_count} deletions, {ins_dup_count} insertions, and duplications merged.")
-
+        # logging.info(f"Chromosome {chromosome} - {del_count} deletions, {ins_dup_count} insertions, and duplications merged.")
+
         current_chromosome += 1
         logging.info(f"Processed {current_chromosome} of {chromosome_count} chromosomes.")
-
-        records_processed += chr_del_breakpoints.shape[0] + chr_ins_dup_breakpoints.shape[0]
         logging.info(f"Processed {records_processed} records of {vcf_df.shape[0]} total records.")
 
+    # Free up memory
     del vcf_df
     del chr_del_df
     del chr_ins_dup_df
-    del chr_del_start
-    del chr_del_end
-    del chr_ins_dup_start
-    del chr_ins_dup_len
-    del chr_ins_dup_end
-    del chr_del_breakpoints
-    del chr_ins_dup_breakpoints
-    del chr_del_depth_scores
-    del chr_ins_dup_depth_scores
-    del del_labels
-    del ins_dup_labels
-    del unique_del_labels
-    del unique_ins_dup_labels
 
     # Open a new VCF file for writing
     logging.info("Writing merged VCF file...")
@@ -365,6 +299,7 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
                              dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \
                             'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}, \
                             chunksize=1000):
+
         # Add an INDEX column to the chunk
         chunk.insert(0, 'INDEX', range(index_start, index_start + chunk.shape[0]))
         index_start += chunk.shape[0]
@@ -373,20 +308,10 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
         # (chunk) but update the INFO field with the merged INFO field.
         # This is done by dropping the INFO column from the chunk so that
         # the INFO column from the merged_records dataframe is used.
-        # matching_records = pd.merge(chunk.drop(columns=['INFO']), merged_records, on=['INDEX'], how='inner')
         matching_records = pd.merge(chunk.drop(columns=['INFO']), merged_records[['INDEX', 'INFO']], on=['INDEX'], how='inner')
-
-        # Print the columns of the matching records
         matching_records = matching_records.drop_duplicates(subset=['INDEX'])  # Drop duplicate records
         matching_records = matching_records.drop(columns=['INDEX'])  # Drop the INDEX column
-
-        # Get the matching records from the chunk by merging on CHROM, POS,
-        # and INFO, but only keep the records from the chunk since they
-        # contain the full VCF record
-        # matching_records = pd.merge(chunk, merged_records, on=['CHROM', 'POS', 'INFO'], how='inner')
-        # matching_records = matching_records.drop_duplicates(subset=['CHROM', 'POS', 'INFO'])  # Drop duplicate records
-
         # Remove the matching records from the merged records dataframe
         merged_records = merged_records[~merged_records.isin(matching_records)].dropna()
@@ -412,6 +337,11 @@ def sv_merger(vcf_file_path, mode='dbscan', eps=100, min_samples=2, suffix='.mer
     # Get the VCF file path from the command line
     vcf_file_path = sys.argv[1]
 
+    # Check if the file exists
+    if not os.path.exists(vcf_file_path):
+        logging.error(f"Error: {vcf_file_path} not found.")
+        sys.exit(1)
+
     # Get the epsilon value from the command line
    if len(sys.argv) > 2:
        eps = int(sys.argv[2])
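
Note for reviewers: below is a minimal, self-contained sketch of the clustering
step this patch introduces, useful for trying the HDBSCAN-based merging in
isolation. The breakpoint coordinates are invented for illustration; only the
HDBSCAN usage (min_cluster_size, fit_predict, and the -1 outlier label) mirrors
the new cluster_breakpoints() code, and the first-member pick below is a
stand-in for the patch's HMM-score/read-support tie-breaking.

    import numpy as np
    from sklearn.cluster import HDBSCAN  # requires scikit-learn>=1.3

    # Invented (start, end) breakpoints: two tight clusters and one outlier
    breakpoints = np.array([
        [1000, 1500], [1002, 1498], [1005, 1510],
        [9000, 9800], [9010, 9795],
        [50000, 50100],
    ])

    # Label each breakpoint pair; -1 marks outliers, which the patch skips
    labels = HDBSCAN(min_cluster_size=2).fit_predict(breakpoints)

    for label in np.unique(labels):
        if label == -1:
            continue
        members = breakpoints[labels == label]
        # The patch keeps one representative record per cluster (lowest HMM
        # score, then highest read support); here we simply take the first
        representative = members[0]
        print(f"Cluster {label}: size={len(members)}, representative={representative}")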